diff --git a/.gitattributes b/.gitattributes index 15caeca0559652cfd179862d561abade06f2740b..01af77dc76a68de64c71a17a166681c6e15c414d 100644 --- a/.gitattributes +++ b/.gitattributes @@ -38,3 +38,13 @@ docs/resources/web-ui.jpg filter=lfs diff=lfs merge=lfs -text docs/resources/dpo_data.png filter=lfs diff=lfs merge=lfs -text docs/transformers/tests/fixtures/tests_samples/COCO/000000039769.png filter=lfs diff=lfs merge=lfs -text docs/transformers/tests/fixtures/tests_samples/COCO/000000004016.png filter=lfs diff=lfs merge=lfs -text +old/dataset_10k_train.jsonl filter=lfs diff=lfs merge=lfs -text +old/.ipynb_checkpoints/dataset_10k_train-checkpoint.jsonl filter=lfs diff=lfs merge=lfs -text +wandb/offline-run-20250720_214625-3kgefhnp/run-3kgefhnp.wandb filter=lfs diff=lfs merge=lfs -text +wandb/offline-run-20250722_000857-dio4c8kj/run-dio4c8kj.wandb filter=lfs diff=lfs merge=lfs -text +wandb/offline-run-20250720_155533-1r0qjmiz/run-1r0qjmiz.wandb filter=lfs diff=lfs merge=lfs -text +wandb/offline-run-20250720_231916-zbtazovk/run-zbtazovk.wandb filter=lfs diff=lfs merge=lfs -text +wandb/offline-run-20250624_115955-iye05c18/run-iye05c18.wandb filter=lfs diff=lfs merge=lfs -text +wandb/offline-run-20250721_000454-up3efnok/run-up3efnok.wandb filter=lfs diff=lfs merge=lfs -text +wandb/offline-run-20250722_003110-femxkckf/run-femxkckf.wandb filter=lfs diff=lfs merge=lfs -text +seamless_interaction/assets/banner.gif filter=lfs diff=lfs merge=lfs -text diff --git a/docs/transformers/build/lib/transformers/models/chameleon/modeling_chameleon.py b/docs/transformers/build/lib/transformers/models/chameleon/modeling_chameleon.py new file mode 100644 index 0000000000000000000000000000000000000000..1c83ddea5a7e1a746442ff55d47340c0558fc77a --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/chameleon/modeling_chameleon.py @@ -0,0 +1,1673 @@ +# coding=utf-8 +# Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Chameleon model.""" + +import math +from functools import cached_property +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) +from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "ChameleonConfig" +_CHECKPOINT_FOR_DOC = "meta/chameleon-7b" +_EXPECTED_OUTPUT_SHAPE = [1, 7, 4096] +_SEQ_CLASS_EXPECTED_LOSS = 1.03 +_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'" + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Chameleon +class ChameleonRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + ChameleonRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(ChameleonRMSNorm) + + +# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Chameleon +# TODO(joao): add me back asap :) +class ChameleonRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with 
torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class ChameleonLinearScalingRotaryEmbedding(ChameleonRotaryEmbedding): + """ChameleonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def forward(self, x, position_ids): + # difference to the original RoPE: a scaling factor is aplied to the position ids + position_ids = position_ids.float() / self.scaling_factor + cos, sin = super().forward(x, position_ids) + return cos, sin + + +class ChameleonDynamicNTKScalingRotaryEmbedding(ChameleonRotaryEmbedding): + """ChameleonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def forward(self, x, position_ids): + # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base + ** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=x.device, dtype=torch.float) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation + + cos, sin = super().forward(x, position_ids) + return cos, sin + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.llama.modeling_llama.LlamaMLP with Llama->Chameleon +class ChameleonMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + # Ignore copy + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class ChameleonLayerNorm(nn.LayerNorm): + """ + LayerNorm but computes stats only over the last dim because Chameleon applies gamma and beta + from each shard separately to each head, instead of reducing. We can apply each head's own + gamma/beta by repeat-interleaving weights from each shard, but the stats have to be computed + in the last dimension. This module applies gamma/beta manually to fulfill this requirement. + """ + + def __init__(self, hidden_size, *args, **kwargs): + super().__init__(hidden_size, *args, **kwargs) + self.normalized_shape = (hidden_size[-1],) + + def forward(self, hidden_states): + hidden_states = F.layer_norm(hidden_states, self.normalized_shape, None, None, eps=1e-5) + hidden_states = hidden_states * self.weight + self.bias + return hidden_states + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class ChameleonAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: ChameleonConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.model_parallel_size = config.model_parallel_size + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim)) + self.k_norm = ChameleonLayerNorm((self.num_key_value_heads, self.head_dim)) + self._init_rope() + + # copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Chameleon + # TODO(joao): add me back asap :) + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = ChameleonRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = ChameleonLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = ChameleonDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.reshape(-1, self.num_heads, self.head_dim) + query_states = self.q_norm(query_states) + + key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim) + key_states = self.k_norm(key_states) + + query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, 
key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# NO LONGER EXIST copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Chameleon +# TODO(joao): add me back asap :) +class ChameleonFlashAttention2(ChameleonAttention): + """ + Chameleon flash attention module. This module inherits from `ChameleonAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + + # Ignore copy + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.reshape(-1, self.num_heads, self.head_dim) + query_states = self.q_norm(query_states) + + key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim) + key_states = self.k_norm(key_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. + # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (ChameleonRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. 
We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class ChameleonSdpaAttention(ChameleonAttention): + """ + Chameleon attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `ChameleonAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from ChameleonAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "ChameleonModel is using ChameleonSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.reshape(-1, self.num_heads, self.head_dim) + query_states = self.q_norm(query_states) + + key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim) + key_states = self.k_norm(key_states) + + query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None and cache_position is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +CHAMELEON_ATTENTION_CLASSES = { + "eager": ChameleonAttention, + "flash_attention_2": ChameleonFlashAttention2, + "sdpa": ChameleonSdpaAttention, +} + + +# copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON +# TODO(joao): add me back asap :) +class ChameleonDecoderLayer(nn.Module): + def __init__(self, config: ChameleonConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = CHAMELEON_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = ChameleonMLP(config) + self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class ChameleonSwinDecoderLayer(nn.Module): + def __init__(self, config: ChameleonConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = CHAMELEON_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = ChameleonMLP(config) + self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. 
+ """ + + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = self.input_layernorm(hidden_states) + hidden_states = residual + hidden_states + # Fully Connected + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class ChameleonVQVAEVectorQuantizer(nn.Module): + """ + A module for vector quantization using learned embedding vectors. + + This module implements the quantization process similar to te one described in + the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous + input vectors into discrete codebook vectors, which are learned during training. + Current implementation improves over previous ones by avoiding costly matrix multiplications + and allowing for post-hoc remapping of indices. + """ + + def __init__(self, config): + super().__init__() + self.num_embeddings = config.num_embeddings + self.embedding_dim = config.embed_dim + self.beta = getattr(config, "beta", 0.25) + + self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim) + + def forward(self, hidden_state: torch.Tensor): + hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous() + hidden_state_flattened = hidden_state.view(-1, self.embedding_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + distances = ( + torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) + - 2 * torch.einsum("bd,dn->bn", hidden_state_flattened, self.embedding.weight.transpose(0, 1)) + ) + + min_encoding_indices = torch.argmin(distances, dim=1) + hidden_state_quant = self.embedding(min_encoding_indices).view(hidden_state.shape) + + # compute loss for embedding + loss = torch.mean((hidden_state_quant.detach() - hidden_state) ** 2) + self.beta * torch.mean( + (hidden_state_quant - hidden_state.detach()) ** 2 + ) + + # preserve gradients + hidden_state_quant = hidden_state + (hidden_state_quant - hidden_state).detach() + + # reshape back to match original input shape + hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous() + + return hidden_state_quant, loss, min_encoding_indices + + +class ChameleonVQVAEEncoderConvDownsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, hidden_states): + # no asymmetric padding in torch conv, must do it ourselves + hidden_states = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class ChameleonVQVAEEncoderResnetBlock(nn.Module): + def __init__( + self, + config, + in_channels, + out_channels=None, + conv_shortcut=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = torch.nn.GroupNorm(num_groups=32, 
num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.dropout = torch.nn.Dropout(config.dropout) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + residual = self.conv_shortcut(residual) + else: + residual = self.nin_shortcut(residual) + + return residual + hidden_states + + +class ChameleonVQVAEEncoderAttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states): + residual = hidden_states + hidden_states = self.norm(hidden_states) + query_states = self.q(hidden_states) + key_states = self.k(hidden_states) + value_states = self.v(hidden_states) + + # compute attention + batch_size, channels, height, width = query_states.shape + query_states = query_states.reshape(batch_size, channels, height * width).permute(0, 2, 1) + key_states = key_states.reshape(batch_size, channels, height * width) + attn_weights = torch.bmm(query_states, key_states) + attn_weights = attn_weights * (int(channels) ** (-0.5)) + attn_weights = F.softmax(attn_weights, dim=2) + + # attend to values + value_states = value_states.reshape(batch_size, channels, height * width) + attn_weights = attn_weights.permute(0, 2, 1) + attn_output = torch.bmm(value_states, attn_weights).reshape(batch_size, channels, height, width) + + attn_output = self.proj_out(attn_output) + return residual + attn_output + + +class ChameleonVQVAEEncoder(nn.Module): + def __init__(self, config): + super().__init__() + + self.num_resolutions = len(config.channel_multiplier) + self.num_res_blocks = config.num_res_blocks + base_channels = config.base_channels + resolution = config.resolution + in_channels = config.in_channels + double_latent = config.double_latent + latent_channels = config.latent_channels + channel_multiplier = config.channel_multiplier + + self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_channel_multiplier = (1,) + tuple(channel_multiplier) + self.in_channel_multiplier = in_channel_multiplier + self.down = nn.ModuleList() + for i_level in 
range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = base_channels * in_channel_multiplier[i_level] + block_out = base_channels * channel_multiplier[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ChameleonVQVAEEncoderResnetBlock( + config=config, + in_channels=block_in, + out_channels=block_out, + ) + ) + block_in = block_out + if ( + config.attn_resolutions is not None + and curr_res in config.attn_resolutions + and config.attn_type == "vanilla" + ): + attn.append(ChameleonVQVAEEncoderAttnBlock(block_in)) + + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + self.mid = nn.Module() + self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock( + config=config, + in_channels=block_in, + out_channels=block_in, + ) + self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock(block_in) if config.attn_type == "vanilla" else nn.Identity() + self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock( + config=config, + in_channels=block_in, + out_channels=block_in, + ) + + self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * latent_channels if double_latent else latent_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, pixel_values: torch.LongTensor): + # downsampling + hidden_states = [self.conv_in(pixel_values)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + hidden_state = self.down[i_level].block[i_block]( + hidden_states[-1], + ) + if len(self.down[i_level].attn) > 0: + hidden_state = self.down[i_level].attn[i_block](hidden_state) + hidden_states.append(hidden_state) + if i_level != self.num_resolutions - 1: + hidden_states.append(self.down[i_level].downsample(hidden_states[-1])) + + # middle + last_hidden_state = hidden_states[-1] + last_hidden_state = self.mid.block_1(last_hidden_state) + last_hidden_state = self.mid.attn_1(last_hidden_state) + last_hidden_state = self.mid.block_2(last_hidden_state) + + # end + last_hidden_state = self.norm_out(last_hidden_state) + last_hidden_state *= torch.sigmoid(last_hidden_state) + last_hidden_state = self.conv_out(last_hidden_state) + return last_hidden_state + + +class ChameleonImageVocabularyMapping: + """ + A class for mapping discrete image tokens from VQGAN to BPE tokens. 
+ """ + + def __init__(self, vocab_map): + self.vocab_map = vocab_map + self.image_token_id = vocab_map.get("") + + @cached_property + def val2name(self): + return {v: k for k, v in self.vocab_map.items()} + + @cached_property + def image_tokens(self): + return sorted([val for name, val in self.vocab_map.items() if name.startswith("IMGIMG")]) + + @cached_property + def bpe2img(self): + img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)} + + def remap(old_name: str) -> str: + return "".join(img_tkn_chr_mapping.get(c, c) for c in old_name[len("IMGIMG") : -1]) + + return {tok: int(remap(self.val2name[tok])) for tok in self.image_tokens} + + @cached_property + def img2bpe(self): + return {v: k for k, v in self.bpe2img.items()} + + @cached_property + def bpe2img_search_tensors(self): + return torch.tensor(sorted(self.bpe2img.keys())), torch.tensor(sorted(self.bpe2img.values())) + + @cached_property + def img2bpe_mapping_tensor(self): + mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int) + for k, v in self.img2bpe.items(): + mapping[k] = v + return mapping + + def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor: + device = img_batch.device + img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")] + return img_tokens.to(device) + + +CHAMELEON_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`ChameleonConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare chameleon Model outputting raw hidden-states without any specific head on top.", + CHAMELEON_START_DOCSTRING, +) +class ChameleonPreTrainedModel(PreTrainedModel): + config_class = ChameleonConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["ChameleonDecoderLayer", "ChameleonSwinDecoderLayer"] + _skip_keys_device_placement = ["past_key_values", "causal_mask"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_quantized_cache = True + _supports_cache_class = True + _supports_static_cache = True + _supports_param_buffer_assignment = False + + def _init_weights(self, module): + std = self.config.initializer_range + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.GroupNorm, nn.LayerNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ChameleonRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +CHAMELEON_VQ_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`ChameleonVQVAEConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + """The VQ-VAE model used in Chameleon for encoding/decoding images into discrete tokens. + This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from + [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://arxiv.org/abs/2203.13131). + """, + CHAMELEON_VQ_START_DOCSTRING, +) +class ChameleonVQVAE(ChameleonPreTrainedModel): + config_class = ChameleonVQVAEConfig + _no_split_modules = ["ChameleonVQVAEVectorQuantizer"] + + def __init__(self, config: ChameleonVQVAEConfig): + super().__init__(config) + + self.encoder = ChameleonVQVAEEncoder(config) + self.quantize = ChameleonVQVAEVectorQuantizer(config) + self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1) + self.eval() # Chameleon's VQ model is frozen + + def encode(self, pixel_values: torch.LongTensor): + hidden_states = self.encoder(pixel_values) + hidden_states = self.quant_conv(hidden_states) + quant, emb_loss, indices = self.quantize(hidden_states) + return quant, emb_loss, indices + + +CHAMELEON_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`ChameleonImageProcessor.__call__`] for details. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Should always be a [`~cache_utils.Cache`] instance and the model will output the same cache instance. + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare chameleon Model outputting raw hidden-states without any specific head on top.", + CHAMELEON_START_DOCSTRING, +) +class ChameleonModel(ChameleonPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`ChameleonDecoderLayer`] + + Args: + config: ChameleonConfig + """ + + def __init__(self, config: ChameleonConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.vocabulary_mapping = ChameleonImageVocabularyMapping(config.vocabulary_map) + decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm else ChameleonSwinDecoderLayer + self.layers = nn.ModuleList( + [decoder_layer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.vqmodel = ChameleonVQVAE._from_config(config.vq_config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def get_image_tokens(self, pixel_values: torch.FloatTensor): + """ + Tokenizes images into discrete tokens with VQGAN module. Converts + obtained image tokens into BPE tokens and wraps with "boi" and "eoi" + special tokens. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. + """ + batch_size = pixel_values.shape[0] + _, _, image_toks = self.vqmodel.encode(pixel_values) + bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks) + bpe_toks = bpe_toks.view(batch_size, -1) + return bpe_toks + + @add_start_docstrings_to_model_forward(CHAMELEON_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPast, + config_class=_CONFIG_FOR_DOC, + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None: + image_tokens = self.get_image_tokens(pixel_values) + special_image_mask = input_ids == self.vocabulary_mapping.image_token_id + if not is_torchdynamo_compiling() and input_ids[special_image_mask].numel() != image_tokens.numel(): + n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum() + n_image_features = image_tokens.shape[0] * image_tokens.shape[1] + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}" + ) + image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) + input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + 
past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. 
+ target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +@add_start_docstrings( + "Chameleon Model with a head on top used for outputting logits for next token prediction.", + CHAMELEON_START_DOCSTRING, +) +class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = ChameleonModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(CHAMELEON_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import ChameleonProcessor, ChameleonForConditionalGeneration + >>> import torch + >>> import requests + >>> from PIL import Image + + >>> model = ChameleonForConditionalGeneration.from_pretrained("facebook/chameleon-7b", torch_dtype=torch.bfloat16) + >>> processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b") + + >>> prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation." + >>> image = Image.open(requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw) + >>> image_2 = Image.open(requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw) + + >>> inputs = processor(images=[image, image_2], text=prompt, return_tensors="pt").to(model.device, torch.bfloat16) + + >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False) + >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + # Disallow image tokens which does not include special begin-image and end-image tokens + image_tokens = self.model.vocabulary_mapping.image_tokens + logits[:, :, image_tokens] = torch.finfo(logits.dtype).min + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + pixel_values=None, + past_key_values=None, + 
attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + pixel_values=pixel_values, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + **kwargs, + ) + + if cache_position[0] != 0: + # If we're in cached decoding stage, pixel values should be `None` because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + model_inputs["pixel_values"] = None + + return model_inputs + + +__all__ = ["ChameleonForConditionalGeneration", "ChameleonModel", "ChameleonPreTrainedModel", "ChameleonVQVAE"] diff --git a/docs/transformers/build/lib/transformers/models/chameleon/processing_chameleon.py b/docs/transformers/build/lib/transformers/models/chameleon/processing_chameleon.py new file mode 100644 index 0000000000000000000000000000000000000000..f0c592180e9f76f9c6ef3efcc69929206a81621d --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/chameleon/processing_chameleon.py @@ -0,0 +1,177 @@ +# coding=utf-8 +# Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for Chameleon. +""" + +from typing import List, Optional, Union + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack, _validate_images_text_input_order +from ...tokenization_utils_base import PreTokenizedInput, TextInput + + +class ChameleonTextKwargs(TextKwargs, total=False): + return_for_text_completion: bool + + +class ChameleonProcessorKwargs(ProcessingKwargs, total=False): + text_kwargs: ChameleonTextKwargs + _defaults = { + "text_kwargs": { + "padding": False, + "return_for_text_completion": False, + }, + "common_kwargs": { + "return_tensors": "pt", + }, + } + + +class ChameleonProcessor(ProcessorMixin): + r""" + Constructs a Chameleon processor which wraps a Chameleon image processor and a Chameleon tokenizer into a single + processor. + + [`ChameleonProcessor`] offers all the functionalities of [`ChameleonImageProcessor`] and [`LlamaTokenizerFast`]. + See the [`~ChameleonProcessor.__call__`] and [`~ChameleonProcessor.decode`] for more information. + + Args: + image_processor ([`ChameleonImageProcessor`]): + The image processor is a required input. + tokenizer ([`LlamaTokenizerFast`]): + The tokenizer is a required input. + image_seq_length (`int`, *optional*, defaults to 1024): + Sequence length of one image embedding. + image_token (`str`, *optional*, defaults to `""`): + The special token used to indicate image in the text. 
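+
+    Example (a minimal usage sketch, assuming the `facebook/chameleon-7b` checkpoint used elsewhere in this
+    file; a real prompt would also include the processor's image placeholder token once per image passed):
+
+    ```python
+    >>> import requests
+    >>> from PIL import Image
+    >>> from transformers import ChameleonProcessor
+
+    >>> processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
+    >>> image = Image.open(requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw)
+    >>> inputs = processor(images=image, text="Describe this constellation.", return_tensors="pt")
+    ```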
+ """ + + attributes = ["image_processor", "tokenizer"] + tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") + valid_kwargs = ["image_seq_length", "image_token"] + image_processor_class = "ChameleonImageProcessor" + + def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = ""): + self.image_seq_length = image_seq_length + self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token + self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token) + self.image_start_token = ( + tokenizer.boi_token if hasattr(tokenizer, "boi_token") else "" + ) # fixed tokens for start and end, so can hardcode + self.image_end_token = tokenizer.eoi_token if hasattr(tokenizer, "eoi_token") else "" + + super().__init__(image_processor, tokenizer) + + def __call__( + self, + images: Optional[ImageInput] = None, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + audio=None, + videos=None, + **kwargs: Unpack[ChameleonProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to + CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring + of the above two methods for more information. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + # check if images and text inputs are reversed for BC + images, text = _validate_images_text_input_order(images, text) + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise TypeError("Invalid input text. 
Please provide a string, or a list of strings") + if text is None and images is None: + raise ValueError("You must provide either text or images") + + output_kwargs = self._merge_kwargs( + ChameleonProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + return_for_text_completion = output_kwargs["text_kwargs"].pop("return_for_text_completion", False) + + # Replace the image token with the expanded image token sequence + prompt_strings = [] + one_img_tokens = self.image_start_token + (self.image_token * self.image_seq_length) + self.image_end_token + for sample in text: + sample = sample.replace(self.image_token, one_img_tokens) + if not return_for_text_completion: + sample += self.tokenizer.sep_token # special Chameleon treatment to add sep for chat mode + prompt_strings.append(sample) + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + data = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) + self._check_special_mm_tokens(prompt_strings, data, modalities=["image"]) + + if images is not None: + data["pixel_values"] = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"] + + return BatchFeature(data=data, tensor_type=return_tensors) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +__all__ = ["ChameleonProcessor"] diff --git a/docs/transformers/build/lib/transformers/models/chinese_clip/configuration_chinese_clip.py b/docs/transformers/build/lib/transformers/models/chinese_clip/configuration_chinese_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..c52b563cb2df9a63591c85d45b0aad99d53f4675 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/chinese_clip/configuration_chinese_clip.py @@ -0,0 +1,434 @@ +# coding=utf-8 +# Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Chinese-CLIP model configuration""" + +from collections import OrderedDict +from typing import TYPE_CHECKING, Any, Mapping, Optional + + +if TYPE_CHECKING: + from ...processing_utils import ProcessorMixin + from ...utils import TensorType + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class ChineseCLIPTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ChineseCLIPModel`]. It is used to instantiate a + Chinese CLIP model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the Chinese CLIP + [OFA-Sys/chinese-clip-vit-base-patch16](https: + //huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the CHINESE_CLIP model. Defines the number of different tokens that can be represented + by the `inputs_ids` passed when calling [`ChineseCLIPModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`ChineseCLIPModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. 
For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + + Example: + + ```python + >>> from transformers import ChineseCLIPTextConfig, ChineseCLIPTextModel + + >>> # Initializing a ChineseCLIPTextConfig with OFA-Sys/chinese-clip-vit-base-patch16 style configuration + >>> configuration = ChineseCLIPTextConfig() + + >>> # Initializing a ChineseCLIPTextModel (with random weights) from the OFA-Sys/chinese-clip-vit-base-patch16 style configuration + >>> model = ChineseCLIPTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "chinese_clip_text_model" + base_config_key = "text_config" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + initializer_factor=1.0, + layer_norm_eps=1e-12, + pad_token_id=0, + position_embedding_type="absolute", + use_cache=True, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + + +class ChineseCLIPVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ChineseCLIPModel`]. It is used to instantiate an + ChineseCLIP model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the ChineseCLIP + [OFA-Sys/chinese-clip-vit-base-patch16](https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + projection_dim (`int`, *optional*, defaults to 512): + Dimensionality of text and vision projection layers. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. 
+ num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 32): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + Example: + ```python + >>> from transformers import ChineseCLIPVisionConfig, ChineseCLIPVisionModel + + >>> # Initializing a ChineseCLIPVisionConfig with OFA-Sys/chinese-clip-vit-base-patch16 style configuration + >>> configuration = ChineseCLIPVisionConfig() + + >>> # Initializing a ChineseCLIPVisionModel (with random weights) from the OFA-Sys/chinese-clip-vit-base-patch16 style configuration + >>> model = ChineseCLIPVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "chinese_clip_vision_model" + base_config_key = "vision_config" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + projection_dim=512, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=224, + patch_size=32, + hidden_act="quick_gelu", + layer_norm_eps=1e-5, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + + +class ChineseCLIPConfig(PretrainedConfig): + r""" + [`ChineseCLIPConfig`] is the configuration class to store the configuration of a [`ChineseCLIPModel`]. It is used + to instantiate Chinese-CLIP model according to the specified arguments, defining the text model and vision model + configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the + Chinese-CLIP [OFA-Sys/chinese-clip-vit-base-patch16](https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16) + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
+ + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`ChineseCLIPTextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`ChineseCLIPVisionConfig`]. + projection_dim (`int`, *optional*, defaults to 512): + Dimensionality of text and vision projection layers. + logit_scale_init_value (`float`, *optional*, defaults to 2.6592): + The initial value of the *logit_scale* parameter. Default is used as per the original ChineseCLIP + implementation. + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import ChineseCLIPConfig, ChineseCLIPModel + + >>> # Initializing a ChineseCLIPConfig with OFA-Sys/chinese-clip-vit-base-patch16 style configuration + >>> configuration = ChineseCLIPConfig() + + >>> # Initializing a ChineseCLIPModel (with random weights) from the OFA-Sys/chinese-clip-vit-base-patch16 style configuration + >>> model = ChineseCLIPModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a ChineseCLIPConfig from a ChineseCLIPTextConfig and a ChineseCLIPVisionConfig + + >>> # Initializing a ChineseCLIPTextConfig and ChineseCLIPVisionConfig configuration + >>> config_text = ChineseCLIPTextConfig() + >>> config_vision = ChineseCLIPVisionConfig() + + >>> config = ChineseCLIPConfig.from_text_vision_configs(config_text, config_vision) + ```""" + + model_type = "chinese_clip" + sub_configs = {"text_config": ChineseCLIPTextConfig, "vision_config": ChineseCLIPVisionConfig} + + def __init__( + self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs + ): + # If `_config_dict` exist, we use them for the backward compatibility. + # We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot + # of confusion!). + text_config_dict = kwargs.pop("text_config_dict", None) + vision_config_dict = kwargs.pop("vision_config_dict", None) + + super().__init__(**kwargs) + + # Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in + # `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be same in most + # cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`. + if text_config_dict is not None: + if text_config is None: + text_config = {} + + # This is the complete result when using `text_config_dict`. + _text_config_dict = ChineseCLIPTextConfig(**text_config_dict).to_dict() + + # Give a warning if the values exist in both `_text_config_dict` and `text_config` but being different. + for key, value in _text_config_dict.items(): + if key in text_config and value != text_config[key] and key not in ["transformers_version"]: + # If specified in `text_config_dict` + if key in text_config_dict: + message = ( + f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. " + f'The value `text_config_dict["{key}"]` will be used instead.' + ) + # If inferred from default argument values (just to be super careful) + else: + message = ( + f"`text_config_dict` is provided which will be used to initialize `ChineseCLIPTextConfig`. " + f'The value `text_config["{key}"]` will be overridden.' + ) + logger.info(message) + + # Update all values in `text_config` with the ones in `_text_config_dict`. 
+ text_config.update(_text_config_dict) + + if vision_config_dict is not None: + if vision_config is None: + vision_config = {} + + # This is the complete result when using `vision_config_dict`. + _vision_config_dict = ChineseCLIPVisionConfig(**vision_config_dict).to_dict() + # convert keys to string instead of integer + if "id2label" in _vision_config_dict: + _vision_config_dict["id2label"] = { + str(key): value for key, value in _vision_config_dict["id2label"].items() + } + + # Give a warning if the values exist in both `_vision_config_dict` and `vision_config` but being different. + for key, value in _vision_config_dict.items(): + if key in vision_config and value != vision_config[key] and key not in ["transformers_version"]: + # If specified in `vision_config_dict` + if key in vision_config_dict: + message = ( + f"`{key}` is found in both `vision_config_dict` and `vision_config` but with different " + f'values. The value `vision_config_dict["{key}"]` will be used instead.' + ) + # If inferred from default argument values (just to be super careful) + else: + message = ( + f"`vision_config_dict` is provided which will be used to initialize " + f'`ChineseCLIPVisionConfig`. The value `vision_config["{key}"]` will be overridden.' + ) + logger.info(message) + + # Update all values in `vision_config` with the ones in `_vision_config_dict`. + vision_config.update(_vision_config_dict) + + if text_config is None: + text_config = {} + logger.info("`text_config` is `None`. Initializing the `ChineseCLIPTextConfig` with default values.") + + if vision_config is None: + vision_config = {} + logger.info("`vision_config` is `None`. initializing the `ChineseCLIPVisionConfig` with default values.") + + self.text_config = ChineseCLIPTextConfig(**text_config) + self.vision_config = ChineseCLIPVisionConfig(**vision_config) + + self.projection_dim = projection_dim + self.logit_scale_init_value = logit_scale_init_value + self.initializer_factor = 1.0 + self.initializer_range = 0.02 + + @classmethod + def from_text_vision_configs( + cls, text_config: ChineseCLIPTextConfig, vision_config: ChineseCLIPVisionConfig, **kwargs + ): + r""" + Instantiate a [`ChineseCLIPConfig`] (or a derived class) from Chinese-CLIP text model configuration and + Chinese-CLIP vision model configuration. 
Returns: + [`ChineseCLIPConfig`]: An instance of a configuration object + """ + + return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) + + +class ChineseCLIPOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("input_ids", {0: "batch", 1: "sequence"}), + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ("attention_mask", {0: "batch", 1: "sequence"}), + ] + ) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("logits_per_image", {0: "batch"}), + ("logits_per_text", {0: "batch"}), + ("text_embeds", {0: "batch"}), + ("image_embeds", {0: "batch"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 + + def generate_dummy_inputs( + self, + processor: "ProcessorMixin", + batch_size: int = -1, + seq_length: int = -1, + framework: Optional["TensorType"] = None, + ) -> Mapping[str, Any]: + text_input_dict = super().generate_dummy_inputs( + processor.tokenizer, batch_size=batch_size, seq_length=seq_length, framework=framework + ) + image_input_dict = super().generate_dummy_inputs( + processor.image_processor, batch_size=batch_size, framework=framework + ) + return {**text_input_dict, **image_input_dict} + + @property + def default_onnx_opset(self) -> int: + return 14 + + +__all__ = ["ChineseCLIPConfig", "ChineseCLIPOnnxConfig", "ChineseCLIPTextConfig", "ChineseCLIPVisionConfig"] diff --git a/docs/transformers/build/lib/transformers/models/chinese_clip/convert_chinese_clip_original_pytorch_to_hf.py b/docs/transformers/build/lib/transformers/models/chinese_clip/convert_chinese_clip_original_pytorch_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..adc9300ef512507a9cf30d1c5cf79aef006a2f3f --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/chinese_clip/convert_chinese_clip_original_pytorch_to_hf.py @@ -0,0 +1,134 @@ +# coding=utf-8 +# Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
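+#
+# Illustrative invocation (a sketch; the paths below are placeholders, while the flag names match the
+# argparse arguments defined at the bottom of this script):
+#
+#   python convert_chinese_clip_original_pytorch_to_hf.py \
+#       --checkpoint_path /path/to/original_chinese_clip_checkpoint.pt \
+#       --config_path /path/to/config.json \
+#       --pytorch_dump_folder_path /path/to/converted_hf_model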
+ +import argparse + +import torch + +from transformers import ChineseCLIPConfig, ChineseCLIPModel + + +def copy_attn_layer(hf_attn_layer, pt_weights, prefix): + q_proj, k_proj, v_proj = pt_weights[f"{prefix}.in_proj_weight"].chunk(3, dim=0) + q_proj_bias, k_proj_bias, v_proj_bias = pt_weights[f"{prefix}.in_proj_bias"].chunk(3, dim=0) + + out_proj_weights = pt_weights[f"{prefix}.out_proj.weight"] + out_proj_bias = pt_weights[f"{prefix}.out_proj.bias"] + + hf_attn_layer.q_proj.weight.data = q_proj + hf_attn_layer.q_proj.bias.data = q_proj_bias + + hf_attn_layer.k_proj.weight.data = k_proj + hf_attn_layer.k_proj.bias.data = k_proj_bias + + hf_attn_layer.v_proj.weight.data = v_proj + hf_attn_layer.v_proj.bias.data = v_proj_bias + + hf_attn_layer.out_proj.weight.data = out_proj_weights + hf_attn_layer.out_proj.bias.data = out_proj_bias + + +def copy_mlp(hf_mlp, pt_weights, prefix): + copy_linear(hf_mlp.fc1, pt_weights, f"{prefix}.c_fc") + copy_linear(hf_mlp.fc2, pt_weights, f"{prefix}.c_proj") + + +def copy_linear(hf_linear, pt_weights, prefix): + hf_linear.weight.data = pt_weights[f"{prefix}.weight"].data + hf_linear.bias.data = pt_weights[f"{prefix}.bias"].data + + +def copy_layer(hf_layer, pt_weights, prefix): + # copy layer norms + copy_linear(hf_layer.layer_norm1, pt_weights, f"{prefix}.ln_1") + copy_linear(hf_layer.layer_norm2, pt_weights, f"{prefix}.ln_2") + + # copy MLP + copy_mlp(hf_layer.mlp, pt_weights, f"{prefix}.mlp") + + # copy attn + copy_attn_layer(hf_layer.self_attn, pt_weights, f"{prefix}.attn") + + +def copy_layers(hf_layers, pt_weights, prefix): + for layer_id, hf_layer in enumerate(hf_layers): + copy_layer(hf_layer, pt_weights, f"{prefix}.{layer_id}") + + +def copy_text_model_and_projection(hf_model, pt_weights): + # copy projection + hf_model.text_projection.weight.data = pt_weights["text_projection"].data.T + + # copy text encoder + for name, param in hf_model.text_model.named_parameters(): + param.data = pt_weights[f"bert.{name}"].data + + +def copy_vision_model_and_projection(hf_model, pt_weights): + # copy projection + hf_model.visual_projection.weight.data = pt_weights["visual.proj"].data.T + + # copy layer norms + copy_linear(hf_model.vision_model.pre_layrnorm, pt_weights, "visual.ln_pre") + copy_linear(hf_model.vision_model.post_layernorm, pt_weights, "visual.ln_post") + + # copy embeddings + hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_weights["visual.conv1.weight"].data + hf_model.vision_model.embeddings.class_embedding.data = pt_weights["visual.class_embedding"].data + hf_model.vision_model.embeddings.position_embedding.weight.data = pt_weights["visual.positional_embedding"].data + + # copy encoder + copy_layers(hf_model.vision_model.encoder.layers, pt_weights, "visual.transformer.resblocks") + + +@torch.no_grad() +def convert_chinese_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None): + """ + Copy/paste/tweak model's weights to transformers design. + """ + + assert config_path is not None, "Please specify the ChineseCLIP model config of the corresponding model size." 
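+
+    # Build an HF ChineseCLIPModel from the provided config; its freshly initialized weights are then
+    # overwritten in place by the copy_* helpers using tensors from the original checkpoint's state dict.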
+ config = ChineseCLIPConfig.from_pretrained(config_path) + + hf_model = ChineseCLIPModel(config).eval() + + pt_weights = torch.load(checkpoint_path, map_location="cpu", weights_only=True)["state_dict"] + pt_weights = {(name[7:] if name.startswith("module.") else name): value for name, value in pt_weights.items()} + + copy_text_model_and_projection(hf_model, pt_weights) + copy_vision_model_and_projection(hf_model, pt_weights) + hf_model.logit_scale.data = pt_weights["logit_scale"].data + + hf_model.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + help="Path to the output folder storing converted hf PyTorch model.", + ) + parser.add_argument( + "--checkpoint_path", default=None, type=str, help="Path to original github format ChineseCLIP checkpoint." + ) + parser.add_argument( + "--config_path", default=None, required=True, type=str, help="Path to hf config.json of model to convert." + ) + args = parser.parse_args() + + convert_chinese_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path) + print("The conversion is finished!") diff --git a/docs/transformers/build/lib/transformers/models/chinese_clip/feature_extraction_chinese_clip.py b/docs/transformers/build/lib/transformers/models/chinese_clip/feature_extraction_chinese_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..c4895bb06b510cfeb64294759c31bcc8d0e3d098 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/chinese_clip/feature_extraction_chinese_clip.py @@ -0,0 +1,38 @@ +# coding=utf-8 +# Copyright 2021 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for Chinese-CLIP.""" + +import warnings + +from ...utils import logging +from ...utils.import_utils import requires +from .image_processing_chinese_clip import ChineseCLIPImageProcessor + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class ChineseCLIPFeatureExtractor(ChineseCLIPImageProcessor): + def __init__(self, *args, **kwargs) -> None: + warnings.warn( + "The class ChineseCLIPFeatureExtractor is deprecated and will be removed in version 5 of Transformers." + " Please use ChineseCLIPImageProcessor instead.", + FutureWarning, + ) + super().__init__(*args, **kwargs) + + +__all__ = ["ChineseCLIPFeatureExtractor"] diff --git a/docs/transformers/build/lib/transformers/models/chinese_clip/image_processing_chinese_clip.py b/docs/transformers/build/lib/transformers/models/chinese_clip/image_processing_chinese_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..d14d286b57d143a3b32b7967df9c97f83da81738 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/chinese_clip/image_processing_chinese_clip.py @@ -0,0 +1,314 @@ +# coding=utf-8 +# Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for Chinese-CLIP.""" + +from typing import Dict, List, Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + convert_to_rgb, + get_resize_output_image_size, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging + + +if is_vision_available(): + import PIL + + +from ...utils.import_utils import requires + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class ChineseCLIPImageProcessor(BaseImageProcessor): + r""" + Constructs a Chinese-CLIP image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`): + Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the + `preprocess` method. + crop_size (`Dict[str, int]` *optional*, defaults to 224): + Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` + method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. 
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_center_crop: bool = True, + crop_size: Dict[str, int] = None, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"shortest_edge": 224} + size = get_size_dict(size, default_to_square=False) + crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} + crop_size = get_size_dict(crop_size) + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.do_convert_rgb = do_convert_rgb + + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge + resized to keep the input aspect ratio. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred from the input + image. 
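+
+        Returns:
+            `np.ndarray`: The resized image.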
+ """ + size = get_size_dict(size, default_to_square=False) + output_size = get_resize_output_image_size( + image, size=(size["height"], size["width"]), default_to_square=False, input_data_format=input_data_format + ) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[int] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the center crop. Only has an effect if `do_center_crop` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. 
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size) + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
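+            # The channels-first vs. channels-last layout is inferred from the shape of the first image.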
+ input_data_format = infer_channel_dimension_format(images[0]) + + all_images = [] + for image in images: + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + + if do_center_crop: + image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format + ) + + all_images.append(image) + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + for image in all_images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["ChineseCLIPImageProcessor"] diff --git a/docs/transformers/build/lib/transformers/models/chinese_clip/image_processing_chinese_clip_fast.py b/docs/transformers/build/lib/transformers/models/chinese_clip/image_processing_chinese_clip_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..a1cb38b8a25f726256e2c47a8b65890efd72361d --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/chinese_clip/image_processing_chinese_clip_fast.py @@ -0,0 +1,40 @@ +# coding=utf-8 +# Copyright 2025 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for Chinese-CLIP.""" + +from ...image_processing_utils_fast import BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, BaseImageProcessorFast +from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling +from ...utils import add_start_docstrings + + +@add_start_docstrings( + "Constructs a fast ChineseCLIP image processor.", + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, +) +class ChineseCLIPImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BICUBIC + image_mean = OPENAI_CLIP_MEAN + image_std = OPENAI_CLIP_STD + size = {"shortest_edge": 224} + default_to_square = False + crop_size = {"height": 224, "width": 224} + do_resize = True + do_center_crop = True + do_rescale = True + do_normalize = True + do_convert_rgb = True + + +__all__ = ["ChineseCLIPImageProcessorFast"] diff --git a/docs/transformers/build/lib/transformers/models/chinese_clip/modeling_chinese_clip.py b/docs/transformers/build/lib/transformers/models/chinese_clip/modeling_chinese_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..647e8f1c2421261d76a5f1f87dfa10c91e5d2fcb --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -0,0 +1,1630 @@ +# coding=utf-8 +# Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Chinese-CLIP model.""" + +import math +from dataclasses import dataclass +from typing import Any, List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPooling, + BaseModelOutputWithPoolingAndCrossAttentions, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, + torch_int, +) +from .configuration_chinese_clip import ChineseCLIPConfig, ChineseCLIPTextConfig, ChineseCLIPVisionConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "OFA-Sys/chinese-clip-vit-base-patch16" +_CONFIG_FOR_DOC = "ChineseCLIPConfig" + + +# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html +# Copied from transformers.models.clip.modeling_clip.contrastive_loss +def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: + return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) + + +def chinese_clip_loss(similarity: torch.Tensor) -> torch.Tensor: + caption_loss = contrastive_loss(similarity) + image_loss = contrastive_loss(similarity.t()) + return (caption_loss + image_loss) / 2.0 + + +@dataclass +class ChineseCLIPOutput(ModelOutput): + """ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of + [`ChineseCLIPTextModel`]. + image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output of + [`ChineseCLIPVisionModel`]. + text_model_output(`BaseModelOutputWithPoolingAndCrossAttentions`): + The output of the [`ChineseCLIPTextModel`]. + vision_model_output(`BaseModelOutputWithPoolingAndCrossAttentions`): + The output of the [`ChineseCLIPVisionModel`]. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: Optional[torch.FloatTensor] = None + logits_per_text: Optional[torch.FloatTensor] = None + text_embeds: Optional[torch.FloatTensor] = None + image_embeds: Optional[torch.FloatTensor] = None + text_model_output: BaseModelOutputWithPoolingAndCrossAttentions = None + vision_model_output: BaseModelOutputWithPoolingAndCrossAttentions = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +# Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert->ChineseCLIPText +class ChineseCLIPTextEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.register_buffer( + "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +# Copied from 
transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->ChineseCLIP +class ChineseCLIPVisionEmbeddings(nn.Module): + def __init__(self, config: ChineseCLIPVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + position_embedding = self.position_embedding.weight.unsqueeze(0) + num_positions = position_embedding.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding(self.position_ids) + + class_pos_embed = position_embedding[:, :1] + patch_pos_embed = position_embedding[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size): + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})." 
+ ) + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->ChineseCLIPText +class ChineseCLIPTextSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. 
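+        # The branches below only decide where the key/value projections come from (a summary of the existing logic):
+        #   - cross-attention with a cache: reuse the cached encoder keys/values and the encoder attention mask
+        #   - cross-attention without a cache: project `encoder_hidden_states`
+        #   - self-attention with a cache (incremental decoding): project `hidden_states` and concatenate the cache
+        #   - plain self-attention: project `hidden_states` only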
+ is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + use_cache = past_key_value is not None + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if use_cache: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in ChineseCLIPTextModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. 
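+        # Masked positions were shifted to large negative scores above, so the softmax assigns them ~0 probability.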
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->ChineseCLIPText +class ChineseCLIPTextSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +CHINESE_CLIP_TEXT_SELF_ATTENTION_CLASSES = { + "eager": ChineseCLIPTextSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ChineseCLIPText,BERT->CHINESE_CLIP_TEXT +class ChineseCLIPTextAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = CHINESE_CLIP_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) + self.output = ChineseCLIPTextSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output 
them + return outputs + + +class ChineseCLIPVisionAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, embed_dim = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scale + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit akward, but it's required to + # make sure that attn_weights keeps its gradient. 
+ # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->ChineseCLIPText +class ChineseCLIPTextIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->ChineseCLIPText +class ChineseCLIPTextOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->ChineseCLIPVision +class ChineseCLIPVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ChineseCLIPText +class ChineseCLIPTextLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = ChineseCLIPTextAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = ChineseCLIPTextAttention(config, position_embedding_type="absolute") + self.intermediate = 
ChineseCLIPTextIntermediate(config) + self.output = ChineseCLIPTextOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class ChineseCLIPVisionLayer(nn.Module): + def __init__(self, config: ChineseCLIPConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = ChineseCLIPVisionAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = ChineseCLIPVisionMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->ChineseCLIPText +class ChineseCLIPTextPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class ChineseCLIPPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ChineseCLIPConfig + base_model_prefix = "chinese_clip" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor + if isinstance(module, ChineseCLIPVisionEmbeddings): + factor = self.config.initializer_factor + nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) + elif isinstance(module, ChineseCLIPTextEmbeddings): + nn.init.normal_(module.word_embeddings.weight, mean=0.0, std=self.config.initializer_range) + nn.init.normal_(module.position_embeddings.weight, mean=0.0, std=self.config.initializer_range) + nn.init.normal_(module.token_type_embeddings.weight, mean=0.0, std=self.config.initializer_range) + for embedding in [module.word_embeddings, module.position_embeddings, module.token_type_embeddings]: + if embedding.padding_idx is not None: + embedding.weight.data[embedding.padding_idx].zero_() + elif isinstance(module, ChineseCLIPVisionAttention): + factor = self.config.initializer_factor + in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + out_proj_std = (module.embed_dim**-0.5) * factor + nn.init.normal_(module.q_proj.weight, std=in_proj_std) + nn.init.normal_(module.k_proj.weight, std=in_proj_std) + nn.init.normal_(module.v_proj.weight, std=in_proj_std) + nn.init.normal_(module.out_proj.weight, std=out_proj_std) + elif isinstance(module, ChineseCLIPVisionMLP): + factor = self.config.initializer_factor + in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + fc_std = (2 * module.config.hidden_size) ** -0.5 * factor + nn.init.normal_(module.fc1.weight, std=fc_std) + nn.init.normal_(module.fc2.weight, std=in_proj_std) + elif isinstance(module, ChineseCLIPModel): + nn.init.normal_( + module.text_projection.weight, + std=module.text_embed_dim**-0.5 * self.config.initializer_factor, + ) + 
nn.init.normal_( + module.visual_projection.weight, + std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, + ) + + if isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + + +CHINESE_CLIP_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`ChineseCLIPConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +CHINESE_CLIP_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults `False`): + Whether to interpolate the pre-trained position encodings. 
+ return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +CHINESE_CLIP_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`ChineseCLIPImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults `False`): + Whether to interpolate the pre-trained position encodings. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +CHINESE_CLIP_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`ChineseCLIPImageProcessor.__call__`] for details. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + + +# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->ChineseCLIPText +class ChineseCLIPTextEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([ChineseCLIPTextLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class ChineseCLIPVisionEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`ChineseCLIPVisionEncoderLayer`]. 
+ + Args: + config: ChineseCLIPConfig + """ + + def __init__(self, config: ChineseCLIPConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ChineseCLIPVisionLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class ChineseCLIPVisionTransformer(nn.Module): + def __init__(self, config: ChineseCLIPVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = ChineseCLIPVisionEmbeddings(config) + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.encoder = ChineseCLIPVisionEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + @add_start_docstrings_to_model_forward(CHINESE_CLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=ChineseCLIPVisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, 
BaseModelOutputWithPooling]:
+        r"""
+        Returns:
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+        hidden_states = self.pre_layrnorm(hidden_states)
+
+        encoder_outputs = self.encoder(
+            inputs_embeds=hidden_states,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        last_hidden_state = encoder_outputs[0]
+        pooled_output = last_hidden_state[:, 0, :]
+        pooled_output = self.post_layernorm(pooled_output)
+
+        if not return_dict:
+            return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=last_hidden_state,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    "The text model from CHINESE_CLIP without any head or projection on top.",
+    CHINESE_CLIP_START_DOCSTRING,
+)
+class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
+    """
+
+    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
+    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
+    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument
+    and `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward
+    pass.
+    """
+
+    config_class = ChineseCLIPTextConfig
+    _no_split_modules = ["ChineseCLIPTextEmbeddings"]
+
+    def __init__(self, config, add_pooling_layer=True):
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = ChineseCLIPTextEmbeddings(config)
+        self.encoder = ChineseCLIPTextEncoder(config)
+
+        self.pooler = ChineseCLIPTextPooler(config) if add_pooling_layer else None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(CHINESE_CLIP_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). 
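+
+        Example (a minimal usage sketch; the checkpoint name follows `_CHECKPOINT_FOR_DOC` above and the tokenizer is
+        resolved via `AutoTokenizer`):
+
+        ```python
+        >>> from transformers import AutoTokenizer, ChineseCLIPTextModel
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
+        >>> model = ChineseCLIPTextModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
+
+        >>> inputs = tokenizer("皮卡丘", return_tensors="pt")
+        >>> outputs = model(**inputs)
+        >>> last_hidden_state = outputs.last_hidden_state  # shape: (batch_size, sequence_length, hidden_size)
+        ```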
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
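+        # `get_extended_attention_mask` turns the (batch_size, seq_length) padding mask into an additive 4D mask
+        # (causal when the config is a decoder) with large negative values at masked positions, ready to be added
+        # to the raw attention scores.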
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@add_start_docstrings( + """The vision model from CHINESE_CLIP without any head or projection on top.""", + CHINESE_CLIP_START_DOCSTRING, +) +class ChineseCLIPVisionModel(ChineseCLIPPreTrainedModel): + config_class = ChineseCLIPVisionConfig + main_input_name = "pixel_values" + _no_split_modules = ["ChineseCLIPVisionEmbeddings", "ChineseCLIPVisionAttention"] + + def __init__(self, config: ChineseCLIPVisionConfig): + super().__init__(config) + self.vision_model = ChineseCLIPVisionTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(CHINESE_CLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=ChineseCLIPVisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import 
CLIPProcessor, ChineseCLIPVisionModel + + >>> model = ChineseCLIPVisionModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16") + >>> processor = CLIPProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16") + + >>> url = "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled CLS states + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + +@add_start_docstrings(CHINESE_CLIP_START_DOCSTRING) +class ChineseCLIPModel(ChineseCLIPPreTrainedModel): + config_class = ChineseCLIPConfig + + def __init__(self, config: ChineseCLIPConfig): + super().__init__(config) + + if not isinstance(config.text_config, ChineseCLIPTextConfig): + raise TypeError( + "config.text_config is expected to be of type ChineseCLIPTextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.vision_config, ChineseCLIPVisionConfig): + raise TypeError( + "config.vision_config is expected to be of type ChineseCLIPVisionConfig but is of type" + f" {type(config.vision_config)}." + ) + + text_config = config.text_config + vision_config = config.vision_config + + self.projection_dim = config.projection_dim + self.text_embed_dim = text_config.hidden_size + self.vision_embed_dim = vision_config.hidden_size + + self.text_model = ChineseCLIPTextModel(text_config, add_pooling_layer=False) + self.vision_model = ChineseCLIPVisionTransformer(vision_config) + + self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False) + self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False) + self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value)) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(CHINESE_CLIP_TEXT_INPUTS_DOCSTRING) + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the final [CLS] hidden state of Text-Transformer. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, ChineseCLIPModel + + >>> model = ChineseCLIPModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16") + >>> tokenizer = AutoTokenizer.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16") + + >>> inputs = tokenizer(["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"], padding=True, return_tensors="pt") + >>> text_features = model.get_text_features(**inputs) + >>> text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True) + ```""" + # Use CHINESE_CLIP model's config for some fields (if specified) instead of those of vision & text components. 
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_outputs[0][:, 0, :] + text_features = self.text_projection(pooled_output) + + return text_features + + @add_start_docstrings_to_model_forward(CHINESE_CLIP_VISION_INPUTS_DOCSTRING) + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the final [CLS] hidden state of Vision-Transformer. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, ChineseCLIPModel + + >>> model = ChineseCLIPModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16") + >>> processor = AutoProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16") + + >>> url = "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> image_features = model.get_image_features(**inputs) + >>> image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True) + ```""" + # Use CHINESE_CLIP model's config for some fields (if specified) instead of those of vision & text components. 
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + pooled_output = vision_outputs[1] # pooled_output + image_features = self.visual_projection(pooled_output) + + return image_features + + @add_start_docstrings_to_model_forward(CHINESE_CLIP_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=ChineseCLIPOutput, config_class=ChineseCLIPConfig) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, ChineseCLIPOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, ChineseCLIPModel + + >>> model = ChineseCLIPModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16") + >>> processor = AutoProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16") + + >>> url = "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(text=["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"], images=image, return_tensors="pt", padding=True) + + >>> outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score + >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities + ```""" + # Use CHINESE_CLIP model's config for some fields (if specified) instead of those of vision & text components. 
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[1] + image_embeds = self.visual_projection(image_embeds) + + text_embeds = text_outputs[0][:, 0, :] + text_embeds = self.text_projection(text_embeds) + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + loss = chinese_clip_loss(logits_per_text) + + if not return_dict: + # fix the None pooled_output of text_outputs to conform with dict_output + pooled_output = text_outputs[1] + if pooled_output is None: + text_outputs = (text_outputs[0],) + text_outputs[2:] + output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) + return ((loss,) + output) if loss is not None else output + + return ChineseCLIPOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +__all__ = ["ChineseCLIPModel", "ChineseCLIPPreTrainedModel", "ChineseCLIPTextModel", "ChineseCLIPVisionModel"] diff --git a/docs/transformers/build/lib/transformers/models/chinese_clip/processing_chinese_clip.py b/docs/transformers/build/lib/transformers/models/chinese_clip/processing_chinese_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..3523c782f3ac38cc2e9a327a1f5fd8759f6f0141 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/chinese_clip/processing_chinese_clip.py @@ -0,0 +1,163 @@ +# coding=utf-8 +# Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
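The `forward` pass above boils down to a scaled cosine similarity between L2-normalized text and image projections. A minimal standalone sketch of that last step, with random tensors standing in for the projected embeddings and a CLIP-style temperature assumed for `logit_scale`:

```python
import torch

# stand-ins for text_projection / visual_projection outputs (projection_dim = 512 by default)
text_embeds = torch.randn(4, 512)
image_embeds = torch.randn(1, 512)

# L2-normalize, then scale the cosine similarities by the learned temperature
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
logit_scale = torch.tensor(2.6592).exp()  # assumed CLIP-style init, log(1 / 0.07)

logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale  # shape (4, 1)
logits_per_image = logits_per_text.t()                                       # shape (1, 4)
probs = logits_per_image.softmax(dim=1)  # per-image distribution over the 4 captions
```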
+""" +Image/Text processor class for Chinese-CLIP +""" + +import warnings +from typing import List, Union + +from ...image_utils import ImageInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput + + +class ChineseClipProcessorKwargs(ProcessingKwargs, total=False): + _defaults = {} + + +class ChineseCLIPProcessor(ProcessorMixin): + r""" + Constructs a Chinese-CLIP processor which wraps a Chinese-CLIP image processor and a Chinese-CLIP tokenizer into a + single processor. + + [`ChineseCLIPProcessor`] offers all the functionalities of [`ChineseCLIPImageProcessor`] and [`BertTokenizerFast`]. + See the [`~ChineseCLIPProcessor.__call__`] and [`~ChineseCLIPProcessor.decode`] for more information. + + Args: + image_processor ([`ChineseCLIPImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`BertTokenizerFast`], *optional*): + The tokenizer is a required input. + """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = ("ChineseCLIPImageProcessor", "ChineseCLIPImageProcessorFast") + tokenizer_class = ("BertTokenizer", "BertTokenizerFast") + + def __init__(self, image_processor=None, tokenizer=None, **kwargs): + feature_extractor = None + if "feature_extractor" in kwargs: + warnings.warn( + "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`" + " instead.", + FutureWarning, + ) + feature_extractor = kwargs.pop("feature_extractor") + + image_processor = image_processor if image_processor is not None else feature_extractor + if image_processor is None: + raise ValueError("You need to specify an `image_processor`.") + if tokenizer is None: + raise ValueError("You need to specify a `tokenizer`.") + + super().__init__(image_processor, tokenizer) + self.current_processor = self.image_processor + + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + images: ImageInput = None, + audio=None, + videos=None, + **kwargs: Unpack[ChineseClipProcessorKwargs], + ) -> BatchEncoding: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to + CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring + of the above two methods for more information. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. 
+ - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + Returns: + [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + + if text is None and images is None: + raise ValueError("You have to specify either text or images. Both cannot be none.") + output_kwargs = self._merge_kwargs( + ChineseClipProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if text is not None: + encoding = self.tokenizer(text, **output_kwargs["text_kwargs"]) + if images is not None: + image_features = self.image_processor(images, **output_kwargs["images_kwargs"]) + + # BC for explicit return_tensors + if "return_tensors" in output_kwargs["common_kwargs"]: + return_tensors = output_kwargs["common_kwargs"].pop("return_tensors", None) + + if text is not None and images is not None: + encoding["pixel_values"] = image_features.pixel_values + return encoding + elif text is not None: + return encoding + else: + return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + @property + def feature_extractor_class(self): + warnings.warn( + "`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.", + FutureWarning, + ) + return self.image_processor_class + + +__all__ = ["ChineseCLIPProcessor"] diff --git a/docs/transformers/build/lib/transformers/models/clap/__init__.py b/docs/transformers/build/lib/transformers/models/clap/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6d54ee86aecef2cbe5b9bfdee321a0375d977880 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/clap/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
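For reference, a brief usage sketch of the processor defined above, reusing the checkpoint name and image URL from the model docstrings earlier in this file:

```python
from PIL import Image
import requests
from transformers import ChineseCLIPProcessor

processor = ChineseCLIPProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")

url = "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg"
image = Image.open(requests.get(url, stream=True).raw)

# text-only, image-only, or joint calls are supported; the joint call merges both outputs
batch = processor(text=["杰尼龟", "皮卡丘"], images=image, return_tensors="pt", padding=True)
print(sorted(batch.keys()))  # ['attention_mask', 'input_ids', 'pixel_values', 'token_type_ids']
```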
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_clap import * + from .feature_extraction_clap import * + from .modeling_clap import * + from .processing_clap import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/clap/configuration_clap.py b/docs/transformers/build/lib/transformers/models/clap/configuration_clap.py new file mode 100644 index 0000000000000000000000000000000000000000..c5b7d3b7a21a96ca93707e64858edc5584ae9303 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/clap/configuration_clap.py @@ -0,0 +1,394 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""CLAP model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class ClapTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ClapTextModel`]. It is used to instantiate a CLAP + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the CLAP + [calp-hsat-fused](https://huggingface.co/laion/clap-hsat-fused) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the CLAP model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ClapTextModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"relu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"relu"`, + `"relu"`, `"silu"` and `"relu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. 
+ max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`ClapTextModel`]. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). + is_decoder (`bool`, *optional*, defaults to `False`): + Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + projection_hidden_act (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function (function or string) in the projection layer. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + projection_dim (`int`, *optional*, defaults to 512) + Dimension of the projection head of the `ClapTextModelWithProjection`. + + Examples: + + ```python + >>> from transformers import ClapTextConfig, ClapTextModel + + >>> # Initializing a CLAP text configuration + >>> configuration = ClapTextConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = ClapTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "clap_text_model" + base_config_key = "text_config" + + def __init__( + self, + vocab_size=50265, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=514, + type_vocab_size=1, + initializer_factor=1.0, + layer_norm_eps=1e-12, + projection_dim=512, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + position_embedding_type="absolute", + use_cache=True, + projection_hidden_act="relu", + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_factor = initializer_factor + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.projection_hidden_act = projection_hidden_act + self.projection_dim = projection_dim + + +class 
ClapAudioConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ClapAudioModel`]. It is used to instantiate a + CLAP audio encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the audio encoder of the CLAP + [laion/clap-htsat-fused](https://huggingface.co/laion/clap-htsat-fused) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + window_size (`int`, *optional*, defaults to 8): + Image size of the spectrogram + num_mel_bins (`int`, *optional*, defaults to 64): + Number of mel features used per frames. Should correspond to the value used in the `ClapProcessor` class. + spec_size (`int`, *optional*, defaults to 256): + Desired input size of the spectrogram that the model supports. It can be different from the output of the + `ClapFeatureExtractor`, in which case the input features will be resized. Corresponds to the `image_size` + of the audio models. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + patch_size (`int`, *optional*, defaults to 4): + Patch size for the audio spectrogram + patch_stride (`list`, *optional*, defaults to `[4, 4]`): + Patch stride for the audio spectrogram + num_classes (`int`, *optional*, defaults to 527): + Number of classes used for the head training + hidden_size (`int`, *optional*, defaults to 768): + Hidden size of the output of the audio encoder. Correspond to the dimension of the penultimate layer's + output,which is sent to the projection MLP layer. + projection_dim (`int`, *optional*, defaults to 512): + Hidden size of the projection layer. + depths (`list`, *optional*, defaults to `[2, 2, 6, 2]`): + Depths used for the Swin Layers of the audio model + num_attention_heads (`list`, *optional*, defaults to `[4, 8, 16, 32]`): + Number of attention heads used for the Swin Layers of the audio model + enable_fusion (`bool`, *optional*, defaults to `False`): + Whether or not to enable patch fusion. This is the main contribution of the authors, and should give the + best results. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the encoder. + fusion_type (`[type]`, *optional*): + Fusion type used for the patch fusion. + patch_embed_input_channels (`int`, *optional*, defaults to 1): + Number of channels used for the input spectrogram + flatten_patch_embeds (`bool`, *optional*, defaults to `True`): + Whether or not to flatten the patch embeddings + patch_embeds_hidden_size (`int`, *optional*, defaults to 96): + Hidden size of the patch embeddings. It is used as the number of output channels. + enable_patch_layer_norm (`bool`, *optional*, defaults to `True`): + Whether or not to enable layer normalization for the patch embeddings + drop_path_rate (`float`, *optional*, defaults to 0.0): + Drop path rate for the patch fusion + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether or not to add a bias to the query, key, value projections. 
+ mlp_ratio (`float`, *optional*, defaults to 4.0): + Ratio of the mlp hidden dim to embedding dim. + aff_block_r (`int`, *optional*, defaults to 4): + downsize_ratio used in the AudioFF block + num_hidden_layers (`int`, *optional*, defaults to 4): + Number of hidden layers in the Transformer encoder. + projection_hidden_act (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function (function or string) in the projection layer. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + layer_norm_eps (`[type]`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + + Example: + + ```python + >>> from transformers import ClapAudioConfig, ClapAudioModel + + >>> # Initializing a ClapAudioConfig with laion/clap-htsat-fused style configuration + >>> configuration = ClapAudioConfig() + + >>> # Initializing a ClapAudioModel (with random weights) from the laion/clap-htsat-fused style configuration + >>> model = ClapAudioModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "clap_audio_model" + base_config_key = "audio_config" + + def __init__( + self, + window_size=8, + num_mel_bins=64, + spec_size=256, + hidden_act="gelu", + patch_size=4, + patch_stride=[4, 4], + num_classes=527, + hidden_size=768, + projection_dim=512, + depths=[2, 2, 6, 2], + num_attention_heads=[4, 8, 16, 32], + enable_fusion=False, + hidden_dropout_prob=0.1, + fusion_type=None, + patch_embed_input_channels=1, + flatten_patch_embeds=True, + patch_embeds_hidden_size=96, + enable_patch_layer_norm=True, + drop_path_rate=0.0, + attention_probs_dropout_prob=0.0, + qkv_bias=True, + mlp_ratio=4.0, + aff_block_r=4, + num_hidden_layers=4, + projection_hidden_act="relu", + layer_norm_eps=1e-5, + initializer_factor=1.0, + **kwargs, + ): + super().__init__(**kwargs) + self.window_size = window_size + self.num_mel_bins = num_mel_bins + self.spec_size = spec_size + self.patch_size = patch_size + self.patch_stride = patch_stride + self.num_classes = num_classes + self.hidden_size = hidden_size + self.depths = depths + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.window_size = window_size + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.projection_dim = projection_dim + self.flatten_patch_embeds = flatten_patch_embeds + self.patch_embeds_hidden_size = patch_embeds_hidden_size + self.enable_patch_layer_norm = enable_patch_layer_norm + self.drop_path_rate = drop_path_rate + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.qkv_bias = qkv_bias + self.mlp_ratio = mlp_ratio + self.patch_embed_input_channels = patch_embed_input_channels + self.aff_block_r = aff_block_r + self.layer_norm_eps = layer_norm_eps + self.initializer_factor = initializer_factor + self.projection_hidden_act = projection_hidden_act + + +class ClapConfig(PretrainedConfig): + r""" + [`ClapConfig`] is the configuration class to store the configuration of a [`ClapModel`]. It is used to instantiate + a CLAP model according to the specified arguments, defining the text model and audio model configs. 
Instantiating a + configuration with the defaults will yield a similar configuration to that of the CLAP + [laion/clap-htsat-fused](https://huggingface.co/laion/clap-htsat-fused) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`ClapTextConfig`]. + audio_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`ClapAudioConfig`]. + logit_scale_init_value (`float`, *optional*, defaults to 14.29): + The initial value of the *logit_scale* parameter. Default is used as per the original CLAP implementation. + projection_dim (`int`, *optional*, defaults to 512): + Dimensionality of text and audio projection layers. + projection_hidden_act (`str`, *optional*, defaults to `"relu"`): + Activation function for the projection layers. + initializer_factor (`float`, *optional*, defaults to 1.0): + Factor to scale the initialization of the model weights. + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import ClapConfig, ClapModel + + >>> # Initializing a ClapConfig with laion-ai/base style configuration + >>> configuration = ClapConfig() + + >>> # Initializing a ClapModel (with random weights) from the laion-ai/base style configuration + >>> model = ClapModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a ClapConfig from a ClapTextConfig and a ClapAudioConfig + >>> from transformers import ClapTextConfig, ClapAudioConfig + + >>> # Initializing a ClapText and ClapAudioConfig configuration + >>> config_text = ClapTextConfig() + >>> config_audio = ClapAudioConfig() + + >>> config = ClapConfig.from_text_audio_configs(config_text, config_audio) + ```""" + + model_type = "clap" + sub_configs = {"text_config": ClapTextConfig, "audio_config": ClapAudioConfig} + + def __init__( + self, + text_config=None, + audio_config=None, + logit_scale_init_value=(1 / 0.07), + projection_dim=512, + projection_hidden_act="relu", + initializer_factor=1.0, + **kwargs, + ): + super().__init__(**kwargs) + + if text_config is None: + text_config = {} + logger.info("text_config is None. Initializing the ClapTextConfig with default values.") + + if audio_config is None: + audio_config = {} + logger.info("audio_config is None. initializing the ClapAudioConfig with default values.") + + self.text_config = ClapTextConfig(**text_config) + self.audio_config = ClapAudioConfig(**audio_config) + self.text_config.projection_dim = projection_dim + self.audio_config.projection_dim = projection_dim + + self.text_config.projection_hidden_act = projection_hidden_act + self.audio_config.projection_hidden_act = projection_hidden_act + + self.projection_dim = projection_dim + self.projection_hidden_act = projection_hidden_act + self.hidden_size = self.text_config.hidden_size + + self.logit_scale_init_value = logit_scale_init_value + self.initializer_factor = initializer_factor + self.num_hidden_layers = self.text_config.num_hidden_layers + len(self.audio_config.depths) + + @classmethod + def from_text_audio_configs(cls, text_config: ClapTextConfig, audio_config: ClapAudioConfig, **kwargs): + r""" + Instantiate a [`ClapConfig`] (or a derived class) from clap text model configuration and clap audio model + configuration. 
+ + Returns: + [`ClapConfig`]: An instance of a configuration object + """ + + return cls(text_config=text_config.to_dict(), audio_config=audio_config.to_dict(), **kwargs) + + +__all__ = ["ClapAudioConfig", "ClapConfig", "ClapTextConfig"] diff --git a/docs/transformers/build/lib/transformers/models/clap/convert_clap_original_pytorch_to_hf.py b/docs/transformers/build/lib/transformers/models/clap/convert_clap_original_pytorch_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..66488e401a1a28817e892d3578f425b6c378fb75 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/clap/convert_clap_original_pytorch_to_hf.py @@ -0,0 +1,133 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import re + +from laion_clap import CLAP_Module + +from transformers import AutoFeatureExtractor, ClapConfig, ClapModel + + +KEYS_TO_MODIFY_MAPPING = { + "text_branch": "text_model", + "audio_branch": "audio_model.audio_encoder", + "attn": "attention.self", + "self.proj": "output.dense", + "attention.self_mask": "attn_mask", + "mlp.fc1": "intermediate.dense", + "mlp.fc2": "output.dense", + "norm1": "layernorm_before", + "norm2": "layernorm_after", + "bn0": "batch_norm", +} + +processor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused", truncation="rand_trunc") + + +def init_clap(checkpoint_path, model_type, enable_fusion=False): + model = CLAP_Module( + amodel=model_type, + enable_fusion=enable_fusion, + ) + model.load_ckpt(checkpoint_path) + return model + + +def get_config_from_original(clap_model): + audio_config = { + "patch_embeds_hidden_size": clap_model.model.audio_branch.embed_dim, + "depths": clap_model.model.audio_branch.depths, + "hidden_size": clap_model.model.audio_projection[0].in_features, + } + + text_config = {"hidden_size": clap_model.model.text_branch.pooler.dense.in_features} + + return ClapConfig(audio_config=audio_config, text_config=text_config) + + +def rename_state_dict(state_dict): + model_state_dict = {} + + sequential_layers_pattern = r".*sequential.(\d+).*" + text_projection_pattern = r".*_projection.(\d+).*" + + for key, value in state_dict.items(): + # check if any key needs to be modified + for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in key: + key = key.replace(key_to_modify, new_key) + + if re.match(sequential_layers_pattern, key): + # replace sequential layers with list + sequential_layer = re.match(sequential_layers_pattern, key).group(1) + + key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.") + elif re.match(text_projection_pattern, key): + projecton_layer = int(re.match(text_projection_pattern, key).group(1)) + + # Because in CLAP they use `nn.Sequential`... 
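+            # index 0 of the original nn.Sequential projection maps to `linear1`, index 2 to `linear2`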
+ transformers_projection_layer = 1 if projecton_layer == 0 else 2 + + key = key.replace(f"_projection.{projecton_layer}.", f"_projection.linear{transformers_projection_layer}.") + + if "audio" and "qkv" in key: + # split qkv into query key and value + mixed_qkv = value + qkv_dim = mixed_qkv.size(0) // 3 + + query_layer = mixed_qkv[:qkv_dim] + key_layer = mixed_qkv[qkv_dim : qkv_dim * 2] + value_layer = mixed_qkv[qkv_dim * 2 :] + + model_state_dict[key.replace("qkv", "query")] = query_layer + model_state_dict[key.replace("qkv", "key")] = key_layer + model_state_dict[key.replace("qkv", "value")] = value_layer + else: + model_state_dict[key] = value + + return model_state_dict + + +def convert_clap_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path, model_type, enable_fusion=False): + clap_model = init_clap(checkpoint_path, model_type, enable_fusion=enable_fusion) + + clap_model.eval() + state_dict = clap_model.model.state_dict() + state_dict = rename_state_dict(state_dict) + + transformers_config = get_config_from_original(clap_model) + transformers_config.audio_config.enable_fusion = enable_fusion + model = ClapModel(transformers_config) + + # ignore the spectrogram embedding layer + model.load_state_dict(state_dict, strict=False) + + model.save_pretrained(pytorch_dump_folder_path) + transformers_config.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint") + parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") + parser.add_argument("--enable_fusion", action="store_true", help="Whether to enable fusion or not") + parser.add_argument("--model_type", default="HTSAT-tiny", type=str, help="Whether to enable fusion or not") + args = parser.parse_args() + + convert_clap_checkpoint( + args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.model_type, args.enable_fusion + ) diff --git a/docs/transformers/build/lib/transformers/models/clap/feature_extraction_clap.py b/docs/transformers/build/lib/transformers/models/clap/feature_extraction_clap.py new file mode 100644 index 0000000000000000000000000000000000000000..cbe51cab7293db1482d4bda727299fb197579435 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/clap/feature_extraction_clap.py @@ -0,0 +1,367 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
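The `qkv` branch of `rename_state_dict` above splits a fused attention projection into separate query, key, and value tensors. A standalone sketch of that slicing, assuming an illustrative fused weight of shape `(3 * dim, dim)`:

```python
import torch

dim = 96                               # illustrative width (e.g. patch_embeds_hidden_size)
mixed_qkv = torch.randn(3 * dim, dim)  # fused qkv weight as stored in the original checkpoint

qkv_dim = mixed_qkv.size(0) // 3
query_weight = mixed_qkv[:qkv_dim]
key_weight = mixed_qkv[qkv_dim : qkv_dim * 2]
value_weight = mixed_qkv[qkv_dim * 2 :]

assert query_weight.shape == key_weight.shape == value_weight.shape == (dim, dim)
```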
+"""Feature extractor class for CLAP.""" + +import copy +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch + +from ...audio_utils import mel_filter_bank, spectrogram, window_function +from ...feature_extraction_sequence_utils import SequenceFeatureExtractor +from ...feature_extraction_utils import BatchFeature +from ...utils import TensorType, logging +from ...utils.import_utils import requires + + +logger = logging.get_logger(__name__) + + +@requires(backends=("torch",)) +class ClapFeatureExtractor(SequenceFeatureExtractor): + r""" + Constructs a CLAP feature extractor. + + This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains + most of the main methods. Users should refer to this superclass for more information regarding those methods. + + This class extracts mel-filter bank features from raw speech using a custom numpy implementation of the *Short Time + Fourier Transform* (STFT) which should match pytorch's `torch.stft` equivalent. + + Args: + feature_size (`int`, *optional*, defaults to 64): + The feature dimension of the extracted Mel spectrograms. This corresponds to the number of mel filters + (`n_mels`). + sampling_rate (`int`, *optional*, defaults to 48000): + The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). This only serves + to warn users if the audio fed to the feature extractor does not have the same sampling rate. + hop_length (`int`,*optional*, defaults to 480): + Length of the overlaping windows for the STFT used to obtain the Mel Spectrogram. The audio will be split + in smaller `frames` with a step of `hop_length` between each frame. + max_length_s (`int`, *optional*, defaults to 10): + The maximum input length of the model in seconds. This is used to pad the audio. + fft_window_size (`int`, *optional*, defaults to 1024): + Size of the window (in samples) on which the Fourier transform is applied. This controls the frequency + resolution of the spectrogram. 400 means that the fourrier transform is computed on windows of 400 samples. + padding_value (`float`, *optional*, defaults to 0.0): + Padding value used to pad the audio. Should correspond to silences. + return_attention_mask (`bool`, *optional*, defaults to `False`): + Whether or not the model should return the attention masks coresponding to the input. + frequency_min (`float`, *optional*, defaults to 0): + The lowest frequency of interest. The STFT will not be computed for values below this. + frequency_max (`float`, *optional*, defaults to 14000): + The highest frequency of interest. The STFT will not be computed for values above this. + top_db (`float`, *optional*): + The highest decibel value used to convert the mel spectrogram to the log scale. For more details see the + `audio_utils.power_to_db` function + truncation (`str`, *optional*, defaults to `"fusion"`): + Truncation pattern for long audio inputs. Two patterns are available: + - `fusion` will use `_random_mel_fusion`, which stacks 3 random crops from the mel spectrogram and a + downsampled version of the entire mel spectrogram. + If `config.fusion` is set to True, shorter audios also need to to return 4 mels, which will just be a copy + of the original mel obtained from the padded audio. + - `rand_trunc` will select a random crop of the mel spectrogram. + padding (`str`, *optional*, defaults to `"repeatpad"`): + Padding pattern for shorter audio inputs. 
Three patterns were originally implemented: + - `repeatpad`: the audio is repeated, and then padded to fit the `max_length`. + - `repeat`: the audio is repeated and then cut to fit the `max_length` + - `pad`: the audio is padded. + """ + + model_input_names = ["input_features", "is_longer"] + + def __init__( + self, + feature_size=64, + sampling_rate=48_000, + hop_length=480, + max_length_s=10, + fft_window_size=1024, + padding_value=0.0, + return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask + frequency_min: float = 0, + frequency_max: float = 14_000, + top_db: Optional[int] = None, + truncation: str = "fusion", + padding: str = "repeatpad", + **kwargs, + ): + super().__init__( + feature_size=feature_size, + sampling_rate=sampling_rate, + padding_value=padding_value, + return_attention_mask=return_attention_mask, + **kwargs, + ) + self.top_db = top_db + self.truncation = truncation + self.padding = padding + self.fft_window_size = fft_window_size + self.nb_frequency_bins = (fft_window_size >> 1) + 1 + self.hop_length = hop_length + self.max_length_s = max_length_s + self.nb_max_samples = max_length_s * sampling_rate + self.sampling_rate = sampling_rate + self.frequency_min = frequency_min + self.frequency_max = frequency_max + self.mel_filters = mel_filter_bank( + num_frequency_bins=self.nb_frequency_bins, + num_mel_filters=feature_size, + min_frequency=frequency_min, + max_frequency=frequency_max, + sampling_rate=sampling_rate, + norm=None, + mel_scale="htk", + ) + self.mel_filters_slaney = mel_filter_bank( + num_frequency_bins=self.nb_frequency_bins, + num_mel_filters=feature_size, + min_frequency=frequency_min, + max_frequency=frequency_max, + sampling_rate=sampling_rate, + norm="slaney", + mel_scale="slaney", + ) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. + + Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, excpet for the + mel filter banks, which do not need to be saved or printed as they are too long. + """ + output = copy.deepcopy(self.__dict__) + output["feature_extractor_type"] = self.__class__.__name__ + if "mel_filters" in output: + del output["mel_filters"] + if "mel_filters_slaney" in output: + del output["mel_filters_slaney"] + return output + + def _np_extract_fbank_features(self, waveform: np.array, mel_filters: Optional[np.array] = None) -> np.ndarray: + """ + Compute the log-mel spectrogram of the provided `waveform` using the Hann window. In CLAP, two different filter + banks are used depending on the truncation pattern: + - `self.mel_filters`: they correspond to the default parameters of `torchaudio` which can be obtained from + calling `torchaudio.transforms.MelSpectrogram().mel_scale.fb`. These filters are used when `truncation` + is set to `"fusion"`. + - `self.mel_filteres_slaney` : they correspond to the default parameters of `librosa` which used + `librosa.filters.mel` when computing the mel spectrogram. These filters were only used in the original + implementation when the truncation mode is not `"fusion"`. 
+ """ + log_mel_spectrogram = spectrogram( + waveform, + window_function(self.fft_window_size, "hann"), + frame_length=self.fft_window_size, + hop_length=self.hop_length, + power=2.0, + mel_filters=mel_filters, + log_mel="dB", + ) + return log_mel_spectrogram.T + + def _random_mel_fusion(self, mel, total_frames, chunk_frames): + ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3) + if len(ranges[1]) == 0: + # if the audio is too short, we just use the first chunk + ranges[1] = [0] + if len(ranges[2]) == 0: + # if the audio is too short, we just use the first chunk + ranges[2] = [0] + # randomly choose index for each part + idx_front = np.random.choice(ranges[0]) + idx_middle = np.random.choice(ranges[1]) + idx_back = np.random.choice(ranges[2]) + + mel_chunk_front = mel[idx_front : idx_front + chunk_frames, :] + mel_chunk_middle = mel[idx_middle : idx_middle + chunk_frames, :] + mel_chunk_back = mel[idx_back : idx_back + chunk_frames, :] + + mel = torch.tensor(mel[None, None, :]) + mel_shrink = torch.nn.functional.interpolate( + mel, size=[chunk_frames, 64], mode="bilinear", align_corners=False + ) + mel_shrink = mel_shrink[0][0].numpy() + mel_fusion = np.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], axis=0) + return mel_fusion + + def _get_input_mel(self, waveform: np.array, max_length, truncation, padding) -> np.array: + """ + Extracts the mel spectrogram and prepares it for the mode based on the `truncation` and `padding` arguments. + Four different path are possible: + - `truncation="fusion"` and the length of the waveform is greater than the max length: the mel spectrogram + will be computed on the entire audio. 3 random crops and a dowsampled version of the full mel spectrogram + are then stacked together. They will later be used for `feature_fusion`. + - `truncation="rand_trunc"` and the length of the waveform is smaller than the max length: the audio is + padded based on `padding`. + - `truncation="fusion"` and the length of the waveform is smaller than the max length: the audio is padded + based on `padding`, and is repeated `4` times. + - `truncation="rand_trunc"` and the length of the waveform is greater than the max length: the mel + spectrogram will be computed on a random crop of the waveform. + + """ + if waveform.shape[0] > max_length: + if truncation == "rand_trunc": + longer = True + # random crop to max_length (for compatibility) -> this should be handled by self.pad + overflow = len(waveform) - max_length + idx = np.random.randint(0, overflow + 1) + waveform = waveform[idx : idx + max_length] + input_mel = self._np_extract_fbank_features(waveform, self.mel_filters_slaney)[None, :] + elif truncation == "fusion": + mel = self._np_extract_fbank_features(waveform, self.mel_filters) + chunk_frames = max_length // self.hop_length + 1 # the +1 related to how the spectrogram is computed + total_frames = mel.shape[0] + if chunk_frames == total_frames: + # there is a corner case where the audio length is larger than max_length but smaller than max_length+hop_length. + # In this case, we just use the whole audio. + input_mel = np.stack([mel, mel, mel, mel], axis=0) + longer = False + else: + input_mel = self._random_mel_fusion(mel, total_frames, chunk_frames) + longer = True + else: + raise NotImplementedError(f"data_truncating {truncation} not implemented") + + else: + longer = False + # only use repeat as a new possible value for padding. 
you repeat the audio before applying the usual max_length padding + if waveform.shape[0] < max_length: + if padding == "repeat": + n_repeat = int(max_length / len(waveform)) + waveform = np.tile(waveform, n_repeat + 1)[:max_length] + if padding == "repeatpad": + n_repeat = int(max_length / len(waveform)) + waveform = np.tile(waveform, n_repeat) + waveform = np.pad(waveform, (0, max_length - waveform.shape[0]), mode="constant", constant_values=0) + + if truncation == "fusion": + input_mel = self._np_extract_fbank_features(waveform, self.mel_filters) + input_mel = np.stack([input_mel, input_mel, input_mel, input_mel], axis=0) + else: + input_mel = self._np_extract_fbank_features(waveform, self.mel_filters_slaney)[None, :] + + return input_mel, longer + + def __call__( + self, + raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]], + truncation: Optional[str] = None, + padding: Optional[str] = None, + max_length: Optional[int] = None, + sampling_rate: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> BatchFeature: + """ + Main method to featurize and prepare for the model one or several sequence(s). + + Args: + raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`): + The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float + values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not + stereo, i.e. single float per timestep. + truncation (`str`, *optional*): + Truncation pattern for long audio inputs. Two patterns are available: + - `fusion` will use `_random_mel_fusion`, which stacks 3 random crops from the mel spectrogram and + a downsampled version of the entire mel spectrogram. + If `config.fusion` is set to True, shorter audios also need to to return 4 mels, which will just be a + copy of the original mel obtained from the padded audio. + - `rand_trunc` will select a random crop of the mel spectrogram. + padding (`str`, *optional*): + Padding pattern for shorter audio inputs. Three patterns were originally implemented: + - `repeatpad`: the audio is repeated, and then padded to fit the `max_length`. + - `repeat`: the audio is repeated and then cut to fit the `max_length` + - `pad`: the audio is padded. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.np.array` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + sampling_rate (`int`, *optional*): + The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass + `sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition + pipeline. + """ + truncation = truncation if truncation is not None else self.truncation + padding = padding if padding else self.padding + + if sampling_rate is not None: + if sampling_rate != self.sampling_rate: + raise ValueError( + f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a" + f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input" + f" was sampled with {self.sampling_rate} and not {sampling_rate}." + ) + else: + logger.warning( + f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. 
" + "Failing to do so can result in silent errors that might be hard to debug." + ) + + is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 + if is_batched_numpy and len(raw_speech.shape) > 2: + raise ValueError(f"Only mono-channel audio is supported for input to {self}") + is_batched = is_batched_numpy or ( + isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) + ) + + if is_batched: + raw_speech = [np.asarray(speech, dtype=np.float64) for speech in raw_speech] + elif not is_batched and not isinstance(raw_speech, np.ndarray): + raw_speech = np.asarray(raw_speech, dtype=np.float64) + elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64): + raw_speech = raw_speech.astype(np.float64) + + # always return batch + if not is_batched: + raw_speech = [np.asarray(raw_speech)] + + # convert to mel spectrogram, truncate and pad if needed. + padded_inputs = [ + self._get_input_mel(waveform, max_length if max_length else self.nb_max_samples, truncation, padding) + for waveform in raw_speech + ] + + input_mel = [] + is_longer = [] + for mel, longer in padded_inputs: + input_mel.append(mel) + is_longer.append(longer) + + if truncation == "fusion" and sum(is_longer) == 0: + # if no audio is longer than 10s, then randomly select one audio to be longer + rand_idx = np.random.randint(0, len(input_mel)) + is_longer[rand_idx] = True + + if isinstance(input_mel[0], List): + input_mel = [np.asarray(feature, dtype=np.float64) for feature in input_mel] + + # is_longer is a list of bool + is_longer = [[longer] for longer in is_longer] + + input_features = {"input_features": input_mel, "is_longer": is_longer} + input_features = BatchFeature(input_features) + + if return_tensors is not None: + input_features = input_features.convert_to_tensors(return_tensors) + + return input_features + + +__all__ = ["ClapFeatureExtractor"] diff --git a/docs/transformers/build/lib/transformers/models/clap/modeling_clap.py b/docs/transformers/build/lib/transformers/models/clap/modeling_clap.py new file mode 100644 index 0000000000000000000000000000000000000000..a7a51cc86af32e9536ee5ca9f238228c4250548e --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/clap/modeling_clap.py @@ -0,0 +1,2314 @@ +# coding=utf-8 +# Copyright 2023 The LAION-AI Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch CLAP model.""" + +import collections +import math +from dataclasses import dataclass +from typing import Any, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...activations import ACT2FN +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPooling, + BaseModelOutputWithPoolingAndCrossAttentions, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, + torch_int, +) +from .configuration_clap import ClapAudioConfig, ClapConfig, ClapTextConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "laion/clap-htsat-fused" + + +# Adapted from: https://github.com/LAION-AI/CLAP/blob/6ad05a971ba0622f6acee8c41993e0d02bbed639/src/open_clip/utils.py#L191 +def interpolate(hidden_states, ratio): + """ + Interpolate data in time domain. This is used to compensate the resolution reduction in downsampling of a CNN. + + Args: + hidden_states (`torch.FloatTensor` of shape (batch_size, time_length, classes_num)): + Input hidden states + ratio (`int`): + The ratio of the length of the output to the length of the input. + """ + (batch_size, time_length, classes_num) = hidden_states.shape + upsampled = hidden_states[:, :, None, :].repeat(1, 1, ratio, 1) + upsampled = upsampled.reshape(batch_size, time_length * ratio, classes_num) + return upsampled + + +# Adapted from https://github.com/LAION-AI/CLAP/blob/6ad05a971ba0622f6acee8c41993e0d02bbed639/src/open_clip/htsat.py#L249 +def window_partition(hidden_states, window_size): + """ + Returns the resized hidden states. The output shape should be `(batch_size * num_windows, window_size, window_size, + num_channels)` + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, height, width, num_channels)`): + Input hidden states + window_size (`int`): + Window size + """ + batch_size, height, width, num_channels = hidden_states.shape + + hidden_states = hidden_states.view( + batch_size, height // window_size, window_size, width // window_size, window_size, num_channels + ) + windows = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels) + return windows + + +# Adapted from https://github.com/LAION-AI/CLAP/blob/6ad05a971ba0622f6acee8c41993e0d02bbed639/src/open_clip/htsat.py#L263 +def window_reverse(windows, window_size, height, width): + """ + Merges windows to produce higher resolution features. + Args: + windows (`torch.FloatTensor` of shape `(num_windows * batch_size, window_size, window_size, num_channels)`): + Input windows + window_size (`int`): + Window size + height (`int`): + Height of the resized audio + width (`int`): + Width of the resized audio + """ + num_channels = windows.shape[-1] + windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels) + windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels) + return windows + + +# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. 
Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + +# contrastive loss function, adapted from +# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html#CLIP-loss-function +def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: + labels = torch.arange(len(logits), device=logits.device) + return nn.functional.cross_entropy(logits, labels) + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Clap +class ClapTextModelOutput(ModelOutput): + """ + Base class for text model's outputs that also contains a pooling of the last hidden states. + + Args: + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + text_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class ClapAudioModelOutput(ModelOutput): + """ + ClapAudio model output to mimic the output of the original implementation. + + Args: + audio_embeds (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + The Audio embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + audio_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Clap, vision->audio, Vision->Audio, image->audio +class ClapOutput(ModelOutput): + """ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for audio-text similarity. + logits_per_audio (`torch.FloatTensor` of shape `(audio_batch_size, text_batch_size)`): + The scaled dot product scores between `audio_embeds` and `text_embeds`. This represents the audio-text + similarity scores. + logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, audio_batch_size)`): + The scaled dot product scores between `text_embeds` and `audio_embeds`. This represents the text-audio + similarity scores. + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`ClapTextModel`]. + audio_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The audio embeddings obtained by applying the projection layer to the pooled output of [`ClapAudioModel`]. + text_model_output (`BaseModelOutputWithPooling`): + The output of the [`ClapTextModel`]. + audio_model_output (`BaseModelOutputWithPooling`): + The output of the [`ClapAudioModel`]. + """ + + loss: Optional[torch.FloatTensor] = None + logits_per_audio: Optional[torch.FloatTensor] = None + logits_per_text: Optional[torch.FloatTensor] = None + text_embeds: Optional[torch.FloatTensor] = None + audio_embeds: Optional[torch.FloatTensor] = None + text_model_output: BaseModelOutputWithPooling = None + audio_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "audio_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +# Adapted from transformers.models.swin.modeling_swin.SwinDropPath +class ClapDropPath(nn.Module): + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is a slightly + refactored version of the `SwinDropPath` implementation. 
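+
+    In training mode each sample's residual branch is dropped with probability `drop_prob` and the kept
+    activations are rescaled by `1 / (1 - drop_prob)` so that the expected output is unchanged; in evaluation
+    mode (or with `drop_prob=0`) the module is the identity.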
+ """ + + def __init__(self, drop_prob=None): + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states): + if self.drop_prob == 0.0 or not self.training: + return hidden_states + + keep_prob = 1 - self.drop_prob + # work with diff dim tensors, not just 2D ConvNets + shape = (hidden_states.shape[0],) + (1,) * (hidden_states.ndim - 1) + + random_tensor = keep_prob + torch.rand(shape, dtype=hidden_states.dtype, device=hidden_states.device) + random_tensor.floor_() # binarize + output = hidden_states.div(keep_prob) * random_tensor + return output + + +# Adapted from https://github.com/LAION-AI/CLAP/blob/6ad05a971ba0622f6acee8c41993e0d02bbed639/src/open_clip/feature_fusion.py#L133 +class ClapAudioAFFBlock(nn.Module): + r""" + ATTENTIONAL FEATURE FUSION Block from CLAP, since in CLAP we are always in 2D mode, it is not needed to implement + the 1D version. + """ + + def __init__(self, config: ClapAudioConfig): + super().__init__() + channels = config.patch_embeds_hidden_size + downsize_ratio = config.aff_block_r + inter_channels = int(channels // downsize_ratio) + + self.local_att = nn.Sequential( + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + self.global_att = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + + self.sigmoid = nn.Sigmoid() + + def forward(self, hidden_states, residual): + attention_input = hidden_states + residual + + fused_layer_output = self.local_att(attention_input) + self.global_att(attention_input) + fused_layer_output = self.sigmoid(fused_layer_output) + + output = 2 * hidden_states * fused_layer_output + 2 * residual * (1 - fused_layer_output) + return output + + +class ClapAudioPatchEmbed(nn.Module): + """ + This module converts the hidden states reshaped as an image to patch embeddings ready to be passed to the + Transformer block. 
+ """ + + def __init__(self, config: ClapAudioConfig): + super().__init__() + img_size = (config.spec_size, config.spec_size) if isinstance(config.spec_size, int) else config.spec_size + patch_size = ( + (config.patch_size, config.patch_size) if isinstance(config.patch_size, int) else config.patch_size + ) + patch_stride = ( + (config.patch_stride, config.patch_stride) if isinstance(config.patch_stride, int) else config.patch_stride + ) + + self.img_size = img_size + self.patch_stride = patch_stride + + self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + + self.flatten = config.flatten_patch_embeds + self.enable_fusion = config.enable_fusion + + padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2) + + scale_factor = 4 if (self.enable_fusion) and (config.fusion_type == "channel_map") else 1 + + self.proj = nn.Conv2d( + config.patch_embed_input_channels * scale_factor, + config.patch_embeds_hidden_size, + kernel_size=patch_size, + stride=patch_stride, + padding=padding, + ) + + self.norm = nn.LayerNorm(config.patch_embeds_hidden_size) if config.enable_patch_layer_norm else nn.Identity() + if self.enable_fusion: + self.fusion_model = ClapAudioAFFBlock(config) + self.mel_conv2d = nn.Conv2d( + config.patch_embed_input_channels, + config.patch_embeds_hidden_size, + kernel_size=(patch_size[0], patch_size[1] * 3), + stride=(patch_stride[0], patch_stride[1] * 3), + padding=padding, + ) + + def forward(self, hidden_states, is_longer_idx=None): + if self.enable_fusion: + # retrieve the last mel as we have transposed the input + global_hidden_states = hidden_states[:, 0:1, :, :] + + # global processing + batch_size, num_channels, height, width = global_hidden_states.shape + + if height != self.img_size[0] or width != self.img_size[1]: + raise ValueError( + f"Input audio size ({height}*{width}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + ) + + global_hidden_states = self.proj(global_hidden_states) + output_width = global_hidden_states.size(-1) + if len(is_longer_idx) > 0: + # local processing + local_hidden_states = hidden_states[is_longer_idx, 1:, :, :].contiguous() + batch_size, num_channels, height, width = local_hidden_states.shape + local_hidden_states = local_hidden_states.view(batch_size * num_channels, 1, height, width) + + local_hidden_states = self.mel_conv2d(local_hidden_states) + + _, features, height, width = local_hidden_states.shape + local_hidden_states = local_hidden_states.view(batch_size, num_channels, features, height, width) + local_hidden_states = local_hidden_states.permute((0, 2, 3, 1, 4)).contiguous().flatten(3) + + local_width = local_hidden_states.size(-1) + local_hidden_states = torch.nn.functional.pad( + local_hidden_states, (0, output_width - local_width), "constant", 0 + ) + + global_hidden_states[is_longer_idx] = self.fusion_model( + global_hidden_states[is_longer_idx], local_hidden_states + ) + hidden_states = global_hidden_states + else: + _, _, height, width = hidden_states.shape + if height != self.img_size[0] or width != self.img_size[1]: + raise ValueError( + f"Input audio size ({height}*{width}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
+ ) + hidden_states = self.proj(hidden_states) + + if self.flatten: + hidden_states = hidden_states.flatten(2).transpose(1, 2) + hidden_states = self.norm(hidden_states) + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->ClapAudio +class ClapAudioSelfAttention(nn.Module): + def __init__(self, config, dim, num_heads, window_size): + super().__init__() + if dim % num_heads != 0: + raise ValueError( + f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})" + ) + + self.num_attention_heads = num_heads + self.attention_head_size = int(dim / num_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.window_size = ( + window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size) + ) + + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads) + ) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index) + + self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + batch_size, dim, num_channels = hidden_states.shape + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
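+        # The scores computed below follow the Swin formulation softmax((Q @ K^T) / sqrt(head_dim) + B): the
+        # relative position bias B is gathered from `relative_position_bias_table` via the precomputed
+        # `relative_position_index` buffer, and attention is restricted to windows of
+        # `window_size[0] * window_size[1]` tokens.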
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)] + relative_position_bias = relative_position_bias.view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 + ) + + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + attention_scores = attention_scores + relative_position_bias.unsqueeze(0) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in ClapAudioModel forward() function) + mask_shape = attention_mask.shape[0] + attention_scores = attention_scores.view( + batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim + ) + attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0) + attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->ClapAudio +class ClapAudioSelfOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, dim) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->ClapAudio +class ClapAudioAttention(nn.Module): + def __init__(self, config, dim, num_heads, window_size): + super().__init__() + self.self = ClapAudioSelfAttention(config, dim, num_heads, window_size) + self.output = ClapAudioSelfOutput(config, dim) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: 
Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinIntermediate with Swin->ClapAudio +class ClapAudioIntermediate(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, int(config.mlp_ratio * dim)) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinOutput with Swin->ClapAudio +class ClapAudioOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(int(config.mlp_ratio * dim), dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinLayer with SwinDropPath->ClapDropPath, Swin->ClapAudio +class ClapAudioLayer(nn.Module): + def __init__(self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.shift_size = shift_size + self.window_size = config.window_size + self.input_resolution = input_resolution + self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.attention = ClapAudioAttention(config, dim, num_heads, window_size=self.window_size) + self.drop_path = ClapDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.intermediate = ClapAudioIntermediate(config, dim) + self.output = ClapAudioOutput(config, dim) + + def set_shift_and_window_size(self, input_resolution): + if min(input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = torch_int(0) + self.window_size = ( + torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution) + ) + + def get_attn_mask(self, height, width, dtype, device): + if self.shift_size > 0: + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device) + height_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + width_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + count = 0 + for height_slice in height_slices: + for width_slice in width_slices: + img_mask[:, height_slice, width_slice, :] = count + count += 1 + + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = 
attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + return attn_mask + + def maybe_pad(self, hidden_states, height, width): + pad_right = (self.window_size - width % self.window_size) % self.window_size + pad_bottom = (self.window_size - height % self.window_size) % self.window_size + pad_values = (0, 0, 0, pad_right, 0, pad_bottom) + hidden_states = nn.functional.pad(hidden_states, pad_values) + return hidden_states, pad_values + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: Tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if not always_partition: + self.set_shift_and_window_size(input_dimensions) + else: + pass + height, width = input_dimensions + batch_size, _, channels = hidden_states.size() + shortcut = hidden_states + + hidden_states = self.layernorm_before(hidden_states) + + hidden_states = hidden_states.view(batch_size, height, width, channels) + + # pad hidden_states to multiples of window size + hidden_states, pad_values = self.maybe_pad(hidden_states, height, width) + + _, height_pad, width_pad, _ = hidden_states.shape + # cyclic shift + if self.shift_size > 0: + shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_hidden_states = hidden_states + + # partition windows + hidden_states_windows = window_partition(shifted_hidden_states, self.window_size) + hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels) + attn_mask = self.get_attn_mask( + height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device + ) + + attention_outputs = self.attention( + hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions + ) + + attention_output = attention_outputs[0] + + attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels) + shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad) + + # reverse cyclic shift + if self.shift_size > 0: + attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + attention_windows = shifted_windows + + was_padded = pad_values[3] > 0 or pad_values[5] > 0 + if was_padded: + attention_windows = attention_windows[:, :height, :width, :].contiguous() + + attention_windows = attention_windows.view(batch_size, height * width, channels) + + hidden_states = shortcut + self.drop_path(attention_windows) + + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + layer_output = hidden_states + self.output(layer_output) + + layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,) + return layer_outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->ClapAudio +class ClapAudioStage(nn.Module): + def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample): + super().__init__() + self.config = config + self.dim = dim + self.blocks = nn.ModuleList( + [ + ClapAudioLayer( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + drop_path_rate=drop_path[i], + shift_size=0 if (i % 2 == 0) else config.window_size // 2, + ) + for i in range(depth) + ] + ) + + # patch merging 
layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm) + else: + self.downsample = None + + self.pointing = False + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: Tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + height, width = input_dimensions + for i, layer_module in enumerate(self.blocks): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition + ) + + hidden_states = layer_outputs[0] + + hidden_states_before_downsampling = hidden_states + if self.downsample is not None: + height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 + output_dimensions = (height, width, height_downsampled, width_downsampled) + hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions) + else: + output_dimensions = (height, width, height, width) + + stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions) + + if output_attentions: + stage_outputs += layer_outputs[1:] + return stage_outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging with Swin->ClapAudio +class ClapAudioPatchMerging(nn.Module): + """ + Patch Merging Layer. + + Args: + input_resolution (`Tuple[int]`): + Resolution of input feature. + dim (`int`): + Number of input channels. + norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`): + Normalization layer class. + """ + + def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None: + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def maybe_pad(self, input_feature, height, width): + should_pad = (height % 2 == 1) or (width % 2 == 1) + if should_pad: + pad_values = (0, 0, 0, width % 2, 0, height % 2) + input_feature = nn.functional.pad(input_feature, pad_values) + + return input_feature + + def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor: + height, width = input_dimensions + # `dim` is height * width + batch_size, dim, num_channels = input_feature.shape + + input_feature = input_feature.view(batch_size, height, width, num_channels) + # pad input to be disible by width and height, if needed + input_feature = self.maybe_pad(input_feature, height, width) + # [batch_size, height/2, width/2, num_channels] + input_feature_0 = input_feature[:, 0::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_1 = input_feature[:, 1::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_2 = input_feature[:, 0::2, 1::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_3 = input_feature[:, 1::2, 1::2, :] + # batch_size height/2 width/2 4*num_channels + input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1) + input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C + + input_feature = self.norm(input_feature) + input_feature = self.reduction(input_feature) + + return input_feature + + +class ClapAudioEncoder(nn.Module): + def __init__(self, config): + super().__init__() + 
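+        # Swin-style hierarchy: one `ClapAudioStage` per entry of `config.depths`, with patch merging between
+        # stages halving each spatial dimension and doubling the channel width, so the deepest stage works at
+        # `patch_embeds_hidden_size * 2 ** (num_layers - 1)` features.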
self.num_layers = len(config.depths) + + self.config = config + self.patch_embed = ClapAudioPatchEmbed(config) + self.enable_fusion = config.enable_fusion + self.patch_stride = self.patch_embed.patch_stride + self.spec_size = config.spec_size + self.freq_ratio = config.spec_size // config.num_mel_bins + + self.num_features = int(config.patch_embeds_hidden_size * 2 ** (self.num_layers - 1)) + + drop_path_rate = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")] + + grid_size = self.patch_embed.grid_size + self.input_resolutions = [(grid_size[0] // (2**i), grid_size[1] // (2**i)) for i in range(self.num_layers)] + + self.layers = nn.ModuleList( + [ + ClapAudioStage( + config=config, + dim=int(config.patch_embeds_hidden_size * 2**i_layer), + input_resolution=self.input_resolutions[i_layer], + depth=config.depths[i_layer], + num_heads=config.num_attention_heads[i_layer], + drop_path=drop_path_rate[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])], + downsample=ClapAudioPatchMerging if (i_layer < self.num_layers - 1) else None, + ) + for i_layer in range(self.num_layers) + ] + ) + + self.gradient_checkpointing = False + + self.batch_norm = nn.BatchNorm2d(config.num_mel_bins) + self.norm = nn.LayerNorm(self.num_features) + self.depths = config.depths + self.avgpool = nn.AdaptiveAvgPool1d(1) + + def reshape_mel2img(self, normalized_input_features): + """ + The input is 4 normalized log mel spectrograms. It is reshape to the common shape of images. Each channel + should represent 1 of the 4 crops of the spectrogram. For more details, refer to the [`ClapFeatureExtractor`]. + """ + _, _, time_length, freq_length = normalized_input_features.shape + + spec_width = int(self.spec_size * self.freq_ratio) + spec_heigth = self.spec_size // self.freq_ratio + + if time_length > spec_width or freq_length > spec_heigth: + raise ValueError("the wav size should be less than or equal to the swin input size") + + # to avoid bicubic zero error + if time_length < spec_width: + normalized_input_features = nn.functional.interpolate( + normalized_input_features, (spec_width, freq_length), mode="bicubic", align_corners=True + ) + if freq_length < spec_heigth: + normalized_input_features = nn.functional.interpolate( + normalized_input_features, (time_length, spec_heigth), mode="bicubic", align_corners=True + ) + + batch, channels, time, freq = normalized_input_features.shape + + # batch_size, channels, spec_width, spec_heigth --> batch_size, channels, spec_heigth * freq_ratio, spec_width // freq_ratio + normalized_input_features = normalized_input_features.reshape( + batch, channels * self.freq_ratio, time // self.freq_ratio, freq + ) + normalized_input_features = normalized_input_features.permute(0, 1, 3, 2).contiguous() + normalized_input_features = normalized_input_features.reshape( + batch, channels, freq * self.freq_ratio, time // self.freq_ratio + ) + + return normalized_input_features + + def forward( + self, + input_features, + is_longer: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + output_hidden_states_before_downsampling: Optional[bool] = False, + always_partition: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, ClapAudioModelOutput]: + input_features = input_features.transpose(1, 3) + normalized_input_features = self.batch_norm(input_features) + normalized_input_features = 
normalized_input_features.transpose(1, 3) + + is_longer_list_idx = None + if self.enable_fusion: + is_longer_list = is_longer.to(input_features.device) + is_longer_list_idx = torch.where(is_longer_list == 1)[0] + + hidden_states = self.reshape_mel2img(normalized_input_features) + + frames_num = hidden_states.shape[2] + + hidden_states = self.patch_embed(hidden_states, is_longer_list_idx) + + all_hidden_states = () if output_hidden_states else None + all_reshaped_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + input_dimensions = self.input_resolutions[0] + + if output_hidden_states: + batch_size, _, hidden_size = hidden_states.shape + # rearrange batch_size (height width) channels -> batch_size channel height width + reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + for i, layer_module in enumerate(self.layers): + layer_head_mask = head_mask[i] if head_mask is not None else None + + input_dimensions = self.input_resolutions[i] + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions + ) + else: + layer_outputs = layer_module( + hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition + ) + + hidden_states = layer_outputs[0] + + hidden_states_before_downsampling = layer_outputs[1] + output_dimensions = layer_outputs[2] + + input_dimensions = (output_dimensions[-2], output_dimensions[-1]) + + if output_hidden_states and output_hidden_states_before_downsampling: + batch_size, _, hidden_size = hidden_states_before_downsampling.shape + # rearrange batch_size (height width) channels -> batch_size channel height width + # here we use the original (not downsampled) height and width + reshaped_hidden_state = hidden_states_before_downsampling.view( + batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size + ) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states_before_downsampling,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + elif output_hidden_states and not output_hidden_states_before_downsampling: + batch_size, _, hidden_size = hidden_states.shape + # rearrange batch_size (height width) channels -> batch_size channel height width + reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + if output_attentions: + all_self_attentions += layer_outputs[3:] + + last_hidden_state = self.norm(hidden_states) + + batch_size, _, n_channels = last_hidden_state.shape + + freq_shape = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] + temporal_shape = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1] + + last_hidden_state = ( + last_hidden_state.permute(0, 2, 1).contiguous().reshape(batch_size, n_channels, freq_shape, temporal_shape) + ) + + batch_size, n_channels, n_frequencies, n_temp = last_hidden_state.shape + # group 2D CNN + c_freq_bin = n_frequencies // self.freq_ratio + last_hidden_state = last_hidden_state.reshape( + batch_size, n_channels, 
n_frequencies // c_freq_bin, c_freq_bin, n_temp + ) + last_hidden_state = ( + last_hidden_state.permute(0, 1, 3, 2, 4).contiguous().reshape(batch_size, n_channels, c_freq_bin, -1) + ) + latent_output = self.avgpool(torch.flatten(last_hidden_state, 2)) + latent_output = torch.flatten(latent_output, 1) + + if not return_dict: + return tuple( + v + for v in [ + last_hidden_state, + latent_output, + all_reshaped_hidden_states, + all_self_attentions, + ] + if v is not None + ) + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=latent_output, + hidden_states=all_reshaped_hidden_states, + attentions=all_self_attentions, + ) + + +CLAP_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`ClapConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +CLAP_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +CLAP_AUDIO_INPUTS_DOCSTRING = r""" + Args: + input_features (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Input audio features. This should be returnes by the [`ClapFeatureExtractor`] class that you can also + retrieve from [`AutoFeatureExtractor`]. See [`ClapFeatureExtractor.__call__`] for details. + is_longer (`torch.FloatTensor`, of shape `(batch_size, 1)`, *optional*): + Whether the audio clip is longer than `max_length`. If `True`, a feature fusion will be enabled to enhance + the features. 
+ output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +CLAP_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + input_features (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Input audio features. This should be returnes by the [`ClapFeatureExtractor`] class that you can also + retrieve from [`AutoFeatureExtractor`]. See [`ClapFeatureExtractor.__call__`] for details. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class ClapProjectionLayer(nn.Module): + def __init__(self, config: Union[ClapAudioConfig, ClapTextConfig]): + super().__init__() + self.config = config + hidden_size = config.hidden_size + projection_dim = config.projection_dim + + self.linear1 = nn.Linear(hidden_size, projection_dim) + self.activation = ACT2FN[config.projection_hidden_act] + self.linear2 = nn.Linear(projection_dim, projection_dim) + + def forward(self, hidden_states): + hidden_states = self.linear1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.linear2(hidden_states) + return hidden_states + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->ClapText, persistent=False->persistent=True +class ClapTextEmbeddings(nn.Module): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. 
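+
+    Position ids are derived with `create_position_ids_from_input_ids`, so real tokens are numbered starting
+    from `padding_idx + 1` while padding tokens keep `padding_idx`, as in RoBERTa.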
+ """ + + # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__ + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=True + ) + self.register_buffer( + "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=True + ) + + # End copy + self.padding_idx = config.pad_token_id + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx + ) + + def forward( + self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) + else: + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) + + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + def create_position_ids_from_inputs_embeds(self, inputs_embeds): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. 
+ + Args: + inputs_embeds: torch.Tensor + + Returns: torch.Tensor + """ + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = torch.arange( + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device + ) + return position_ids.unsqueeze(0).expand(input_shape) + + +# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->ClapText +class ClapTextSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. 
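+        # Four cases are handled below: cached cross-attention (the encoder key/value states are reused),
+        # fresh cross-attention over `encoder_hidden_states`, cached decoder self-attention (new key/value
+        # states are concatenated to the cache along the sequence dimension), and plain self-attention with
+        # no cache.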
+ is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + use_cache = past_key_value is not None + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if use_cache: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in ClapTextModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. 
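+        # The additive `attention_mask` applied above holds large negative values at masked positions, so the
+        # softmax assigns them (near-)zero probability mass.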
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput +class ClapTextSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +CLAP_TEXT_SELF_ATTENTION_CLASSES = { + "eager": ClapTextSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ClapText,BERT->CLAP_TEXT +class ClapTextAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = CLAP_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) + self.output = ClapTextSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from 
transformers.models.bert.modeling_bert.BertIntermediate +class ClapTextIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput +class ClapTextOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ClapText +class ClapTextLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = ClapTextAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = ClapTextAttention(config, position_embedding_type="absolute") + self.intermediate = ClapTextIntermediate(config) + self.output = ClapTextOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( 
+ attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->ClapText +class ClapTextEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([ClapTextLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPooler +class ClapTextPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class ClapPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = ClapConfig + base_model_prefix = "clap" + supports_gradient_checkpointing = False + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor + + if isinstance(module, ClapTextEmbeddings): + module.position_embeddings.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.token_type_embeddings.weight.data.normal_(mean=0.0, std=factor * 0.02) + elif isinstance(module, ClapModel): + nn.init.normal_(module.logit_scale_a, std=factor * 0.02) + nn.init.normal_(module.logit_scale_t, std=factor * 0.02) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=factor * 0.02) + + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, (nn.Conv2d, nn.Linear)): + in_proj_std = (self.config.hidden_size**-0.5) * ((2 * self.config.num_hidden_layers) ** -0.5) * factor + nn.init.normal_(module.weight, std=in_proj_std) + if module.bias is not None: + module.bias.data.zero_() + + +class ClapAudioModel(ClapPreTrainedModel): + config_class = ClapAudioConfig + main_input_name = "input_features" + + def __init__(self, config: ClapAudioConfig): + super().__init__(config) + self.audio_encoder = ClapAudioEncoder(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.audio_encoder.patch_embed.proj + + @add_start_docstrings_to_model_forward(CLAP_AUDIO_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=ClapAudioConfig) + def forward( + self, + input_features: Optional[torch.FloatTensor] = None, + is_longer: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from datasets import load_dataset + >>> from transformers import AutoProcessor, ClapAudioModel + + >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example") + >>> audio_sample = dataset["train"]["audio"][0]["array"] + + >>> model = ClapAudioModel.from_pretrained("laion/clap-htsat-fused") + >>> processor = AutoProcessor.from_pretrained("laion/clap-htsat-fused") + + >>> inputs = processor(audios=audio_sample, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + return self.audio_encoder( + input_features=input_features, + is_longer=is_longer, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class ClapTextModel(ClapPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in *Attention is + all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz + Kaiser and Illia Polosukhin. 
+
+    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
+    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument
+    and `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
+
+    .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762
+
+    """
+
+    config_class = ClapTextConfig
+
+    def __init__(self, config, add_pooling_layer=True):
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = ClapTextEmbeddings(config)
+        self.encoder = ClapTextEncoder(config)
+
+        self.pooler = ClapTextPooler(config) if add_pooling_layer else None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        r"""
+        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@add_start_docstrings(CLAP_START_DOCSTRING) +class ClapModel(ClapPreTrainedModel): + config_class = ClapConfig + + def __init__(self, config: ClapConfig): + super().__init__(config) + + if not isinstance(config.text_config, ClapTextConfig): + raise TypeError( + "config.text_config is expected to be of type ClapTextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.audio_config, ClapAudioConfig): + raise TypeError( + "config.audio_config is expected to be of type ClapAudioConfig but is of type" + f" {type(config.audio_config)}." 
+            )
+
+        text_config = config.text_config
+        audio_config = config.audio_config
+
+        self.logit_scale_a = nn.Parameter(torch.tensor(math.log(config.logit_scale_init_value)))
+        self.logit_scale_t = nn.Parameter(torch.tensor(math.log(config.logit_scale_init_value)))
+
+        self.projection_dim = config.projection_dim
+
+        self.text_model = ClapTextModel(text_config)
+        self.text_projection = ClapProjectionLayer(text_config)
+
+        self.audio_model = ClapAudioModel(audio_config)
+        self.audio_projection = ClapProjectionLayer(audio_config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(CLAP_TEXT_INPUTS_DOCSTRING)
+    def get_text_features(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> torch.FloatTensor:
+        r"""
+        Returns:
+            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
+            applying the projection layer to the pooled output of [`ClapTextModel`].
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoTokenizer, ClapModel
+
+        >>> model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
+        >>> tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")
+
+        >>> inputs = tokenizer(["the sound of a cat", "the sound of a dog"], padding=True, return_tensors="pt")
+        >>> text_features = model.get_text_features(**inputs)
+        ```"""
+        # Use CLAP model's config for some fields (if specified) instead of those of audio & text components.
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        text_outputs = self.text_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        pooled_output = text_outputs[1] if not return_dict else text_outputs.pooler_output
+        text_features = self.text_projection(pooled_output)
+        text_features = F.normalize(text_features, dim=-1)
+
+        return text_features
+
+    @add_start_docstrings_to_model_forward(CLAP_AUDIO_INPUTS_DOCSTRING)
+    def get_audio_features(
+        self,
+        input_features: Optional[torch.Tensor] = None,
+        is_longer: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> torch.FloatTensor:
+        r"""
+        Returns:
+            audio_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The audio embeddings obtained by
+            applying the projection layer to the pooled output of [`ClapAudioModel`].
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoFeatureExtractor, ClapModel
+        >>> import torch
+
+        >>> model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
+        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused")
+        >>> random_audio = torch.rand((16_000))
+        >>> inputs = feature_extractor(random_audio, return_tensors="pt")
+        >>> audio_features = model.get_audio_features(**inputs)
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        audio_outputs = self.audio_model(
+            input_features=input_features,
+            is_longer=is_longer,
+            return_dict=return_dict,
+        )
+
+        pooled_output = audio_outputs[1] if not return_dict else audio_outputs.pooler_output
+
+        audio_features = self.audio_projection(pooled_output)
+        audio_features = F.normalize(audio_features, dim=-1)
+
+        return audio_features
+
+    @add_start_docstrings_to_model_forward(CLAP_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=ClapOutput, config_class=ClapConfig)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        input_features: Optional[torch.FloatTensor] = None,
+        is_longer: Optional[torch.BoolTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        return_loss: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, ClapOutput]:
+        r"""
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from datasets import load_dataset
+        >>> from transformers import AutoProcessor, ClapModel
+
+        >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example")
+        >>> audio_sample = dataset["train"]["audio"][0]["array"]
+
+        >>> model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
+        >>> processor = AutoProcessor.from_pretrained("laion/clap-htsat-unfused")
+
+        >>> input_text = ["Sound of a dog", "Sound of vacuum cleaner"]
+
+        >>> inputs = processor(text=input_text, audios=audio_sample, return_tensors="pt", padding=True)
+
+        >>> outputs = model(**inputs)
+        >>> logits_per_audio = outputs.logits_per_audio  # this is the audio-text similarity score
+        >>> probs = logits_per_audio.softmax(dim=-1)  # we can take the softmax to get the label probabilities
+        ```"""
+        # Use CLAP model's config for some fields (if specified) instead of those of audio & text components.
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + audio_outputs = self.audio_model( + input_features=input_features, + is_longer=is_longer, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + audio_embeds = audio_outputs[1] if not return_dict else audio_outputs.pooler_output + audio_embeds = self.audio_projection(audio_embeds) + + text_embeds = text_outputs[1] if not return_dict else text_outputs.pooler_output + text_embeds = self.text_projection(text_embeds) + + # normalized features + audio_embeds = audio_embeds / audio_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale_text = self.logit_scale_t.exp() + logit_scale_audio = self.logit_scale_a.exp() + logits_per_text = torch.matmul(text_embeds, audio_embeds.t()) * logit_scale_text + logits_per_audio = torch.matmul(audio_embeds, text_embeds.t()) * logit_scale_audio + + loss = None + if return_loss: + caption_loss = contrastive_loss(logits_per_text) + audio_loss = contrastive_loss(logits_per_audio.t()) + loss = (caption_loss + audio_loss) / 2.0 + + if not return_dict: + output = (logits_per_audio, logits_per_text, text_embeds, audio_embeds, text_outputs, audio_outputs) + return ((loss,) + output) if loss is not None else output + + return ClapOutput( + loss=loss, + logits_per_audio=logits_per_audio, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + audio_embeds=audio_embeds, + text_model_output=text_outputs, + audio_model_output=audio_outputs, + ) + + +@add_start_docstrings( + """ + CLAP Text Model with a projection layer on top (a linear layer on top of the pooled output). 
+ """, + CLAP_START_DOCSTRING, +) +class ClapTextModelWithProjection(ClapPreTrainedModel): + config_class = ClapTextConfig + + def __init__(self, config: ClapTextConfig): + super().__init__(config) + self.text_model = ClapTextModel(config) + self.text_projection = ClapProjectionLayer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.text_model.embeddings.word_embeddings = value + + @add_start_docstrings_to_model_forward(CLAP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=ClapTextModelOutput, config_class=ClapTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, ClapTextModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, ClapTextModelWithProjection + + >>> model = ClapTextModelWithProjection.from_pretrained("laion/clap-htsat-unfused") + >>> tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused") + + >>> inputs = tokenizer(["a sound of a cat", "a sound of a dog"], padding=True, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> text_embeds = outputs.text_embeds + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_outputs[1] if not return_dict else text_outputs.pooler_output + + text_embeds = self.text_projection(pooled_output) + + if not return_dict: + outputs = (text_embeds, text_outputs[0]) + text_outputs[2:] + return tuple(output for output in outputs if output is not None) + + return ClapTextModelOutput( + text_embeds=text_embeds, + last_hidden_state=text_outputs.last_hidden_state, + hidden_states=text_outputs.hidden_states, + attentions=text_outputs.attentions, + ) + + +@add_start_docstrings( + """ + CLAP Audio Model with a projection layer on top (a linear layer on top of the pooled output). 
+ """, + CLAP_START_DOCSTRING, +) +class ClapAudioModelWithProjection(ClapPreTrainedModel): + config_class = ClapAudioConfig + main_input_name = "input_features" + + def __init__(self, config: ClapAudioConfig): + super().__init__(config) + self.audio_model = ClapAudioModel(config) + self.audio_projection = ClapProjectionLayer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.audio_model.audio_encoder.patch_embed.proj + + @add_start_docstrings_to_model_forward(CLAP_AUDIO_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=ClapAudioModelOutput, config_class=ClapAudioConfig) + def forward( + self, + input_features: Optional[torch.FloatTensor] = None, + is_longer: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, ClapAudioModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from datasets import load_dataset + >>> from transformers import ClapAudioModelWithProjection, ClapProcessor + + >>> model = ClapAudioModelWithProjection.from_pretrained("laion/clap-htsat-fused") + >>> processor = ClapProcessor.from_pretrained("laion/clap-htsat-fused") + + >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example") + >>> audio_sample = dataset["train"]["audio"][0]["array"] + + >>> inputs = processor(audios=audio_sample, return_tensors="pt") + >>> outputs = model(**inputs) + >>> audio_embeds = outputs.audio_embeds + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + audio_outputs = self.audio_model( + input_features=input_features, + is_longer=is_longer, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = audio_outputs[1] if not return_dict else audio_outputs.pooler_output + + audio_embeds = self.audio_projection(pooled_output) + + if not return_dict: + outputs = (audio_embeds, audio_outputs[0]) + audio_outputs[2:] + return tuple(output for output in outputs if output is not None) + + return ClapAudioModelOutput( + audio_embeds=audio_embeds, + last_hidden_state=audio_outputs.last_hidden_state, + attentions=audio_outputs.attentions, + hidden_states=audio_outputs.hidden_states, + ) + + +__all__ = [ + "ClapModel", + "ClapPreTrainedModel", + "ClapTextModel", + "ClapTextModelWithProjection", + "ClapAudioModel", + "ClapAudioModelWithProjection", +] diff --git a/docs/transformers/build/lib/transformers/models/clap/processing_clap.py b/docs/transformers/build/lib/transformers/models/clap/processing_clap.py new file mode 100644 index 0000000000000000000000000000000000000000..126fc384ebfbfb53a55c08237ed1e951968bed10 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/clap/processing_clap.py @@ -0,0 +1,120 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Audio/Text processor class for CLAP
+"""
+
+from ...processing_utils import ProcessorMixin
+from ...tokenization_utils_base import BatchEncoding
+
+
+class ClapProcessor(ProcessorMixin):
+    r"""
+    Constructs a CLAP processor which wraps a CLAP feature extractor and a RoBERTa tokenizer into a single processor.
+
+    [`ClapProcessor`] offers all the functionalities of [`ClapFeatureExtractor`] and [`RobertaTokenizerFast`]. See the
+    [`~ClapProcessor.__call__`] and [`~ClapProcessor.decode`] for more information.
+
+    Args:
+        feature_extractor ([`ClapFeatureExtractor`]):
+            The audio processor is a required input.
+        tokenizer ([`RobertaTokenizerFast`]):
+            The tokenizer is a required input.
+    """
+
+    feature_extractor_class = "ClapFeatureExtractor"
+    tokenizer_class = ("RobertaTokenizer", "RobertaTokenizerFast")
+
+    def __init__(self, feature_extractor, tokenizer):
+        super().__init__(feature_extractor, tokenizer)
+
+    def __call__(self, text=None, audios=None, return_tensors=None, **kwargs):
+        """
+        Main method to prepare one or several sequence(s) and audio(s) for the model. This method forwards the `text`
+        and `kwargs` arguments to RobertaTokenizerFast's [`~RobertaTokenizerFast.__call__`] if `text` is not `None` to
+        encode the text. To prepare the audio(s), this method forwards the `audios` and `kwargs` arguments to
+        ClapFeatureExtractor's [`~ClapFeatureExtractor.__call__`] if `audios` is not `None`. Please refer to the
+        docstring of the above two methods for more information.
+
+        Args:
+            text (`str`, `List[str]`, `List[List[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            audios (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                The audio or batch of audios to be prepared. Each audio can be a NumPy array or a PyTorch tensor. In
+                the case of a NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is the number
+                of channels and T the sample length of the audio.
+
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors of a particular framework. Acceptable values are:
+
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return NumPy `np.ndarray` objects.
+                - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+        Returns:
+            [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
+
+            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+              `None`).
+            - **audio_features** -- Audio features to be fed to a model. Returned when `audios` is not `None`.
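+
+        Example (a minimal illustrative sketch, not an official reference snippet; it assumes the public
+        `laion/clap-htsat-unfused` checkpoint and a one-second 48 kHz mono clip, CLAP's default sampling rate):
+
+        ```python
+        >>> import numpy as np
+        >>> from transformers import ClapProcessor
+
+        >>> processor = ClapProcessor.from_pretrained("laion/clap-htsat-unfused")
+        >>> audio = np.random.rand(48_000).astype(np.float32)  # placeholder mono waveform, one second at 48 kHz
+        >>> inputs = processor(text=["a sound of a dog"], audios=audio, sampling_rate=48_000, return_tensors="pt")
+        ```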
+ """ + sampling_rate = kwargs.pop("sampling_rate", None) + + if text is None and audios is None: + raise ValueError("You have to specify either text or audios. Both cannot be none.") + + if text is not None: + encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs) + + if audios is not None: + audio_features = self.feature_extractor( + audios, sampling_rate=sampling_rate, return_tensors=return_tensors, **kwargs + ) + + if text is not None and audios is not None: + encoding.update(audio_features) + return encoding + elif text is not None: + return encoding + else: + return BatchEncoding(data=dict(**audio_features), tensor_type=return_tensors) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to RobertaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to RobertaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer + to the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + feature_extractor_input_names = self.feature_extractor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names)) + + +__all__ = ["ClapProcessor"] diff --git a/docs/transformers/build/lib/transformers/models/clip/__init__.py b/docs/transformers/build/lib/transformers/models/clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..18a4db32e9943d78adb459ee9bffeb2222ce4107 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/clip/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_clip import * + from .feature_extraction_clip import * + from .image_processing_clip import * + from .image_processing_clip_fast import * + from .modeling_clip import * + from .modeling_flax_clip import * + from .modeling_tf_clip import * + from .processing_clip import * + from .tokenization_clip import * + from .tokenization_clip_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/clip/convert_clip_original_pytorch_to_hf.py b/docs/transformers/build/lib/transformers/models/clip/convert_clip_original_pytorch_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..3d88fc1929c30bf71decb229a87c8b4b8b794b31 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/clip/convert_clip_original_pytorch_to_hf.py @@ -0,0 +1,156 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import torch +from clip import load + +from transformers import CLIPConfig, CLIPModel + + +def copy_attn_layer(hf_attn_layer, pt_attn_layer): + q_proj, k_proj, v_proj = pt_attn_layer.in_proj_weight.chunk(3, dim=0) + q_proj_bias, k_proj_bias, v_proj_bias = pt_attn_layer.in_proj_bias.chunk(3, dim=0) + + out_proj_weights = pt_attn_layer.out_proj.weight + out_proj_bias = pt_attn_layer.out_proj.bias + + hf_attn_layer.q_proj.weight.data = q_proj + hf_attn_layer.q_proj.bias.data = q_proj_bias + + hf_attn_layer.k_proj.weight.data = k_proj + hf_attn_layer.k_proj.bias.data = k_proj_bias + + hf_attn_layer.v_proj.weight.data = v_proj + hf_attn_layer.v_proj.bias.data = v_proj_bias + + hf_attn_layer.out_proj.weight = out_proj_weights + hf_attn_layer.out_proj.bias = out_proj_bias + + +def copy_mlp(hf_mlp, pt_mlp): + copy_linear(hf_mlp.fc1, pt_mlp.c_fc) + copy_linear(hf_mlp.fc2, pt_mlp.c_proj) + + +def copy_linear(hf_linear, pt_linear): + hf_linear.weight = pt_linear.weight + hf_linear.bias = pt_linear.bias + + +def copy_layer(hf_layer, pt_layer): + # copy layer norms + copy_linear(hf_layer.layer_norm1, pt_layer.ln_1) + copy_linear(hf_layer.layer_norm2, pt_layer.ln_2) + + # copy MLP + copy_mlp(hf_layer.mlp, pt_layer.mlp) + + # copy attn + copy_attn_layer(hf_layer.self_attn, pt_layer.attn) + + +def copy_layers(hf_layers, pt_layers): + for hf_layer, pt_layer in zip(hf_layers, pt_layers): + copy_layer(hf_layer, pt_layer) + + +def copy_encoder(hf_encoder, pt_model): + # copy embeds + hf_encoder.embeddings.token_embedding.weight = pt_model.token_embedding.weight + hf_encoder.embeddings.position_embedding.weight.data = pt_model.positional_embedding + + # copy layer norm + copy_linear(hf_encoder.final_layer_norm, pt_model.ln_final) + + # copy hidden layers + 
copy_layers(hf_encoder.encoder.layers, pt_model.transformer.resblocks) + + +def copy_text_model_and_projection(hf_model, pt_model): + # copy projection + hf_model.text_projection.weight.data = pt_model.text_projection.data.T.contiguous() + + # copy text encoder + copy_encoder(hf_model.text_model, pt_model) + + +def copy_vison_model_and_projection(hf_model, pt_model): + # copy projection + hf_model.visual_projection.weight.data = pt_model.visual.proj.data.T.contiguous() + + # copy layer norms + copy_linear(hf_model.vision_model.pre_layrnorm, pt_model.visual.ln_pre) + copy_linear(hf_model.vision_model.post_layernorm, pt_model.visual.ln_post) + + # copy embeds + hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_model.visual.conv1.weight.data + hf_model.vision_model.embeddings.class_embedding = pt_model.visual.class_embedding + hf_model.vision_model.embeddings.position_embedding.weight.data = pt_model.visual.positional_embedding.data + + # copy encoder + copy_layers(hf_model.vision_model.encoder.layers, pt_model.visual.transformer.resblocks) + + +@torch.no_grad() +def convert_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None): + """ + Copy/paste/tweak model's weights to transformers design. + """ + if config_path is not None: + config = CLIPConfig.from_pretrained(config_path) + else: + config = CLIPConfig(projection_dim=512, text_config={}, vision_config={}) + + hf_model = CLIPModel(config).eval() + + pt_model, _ = load(checkpoint_path, device="cpu", jit=False) + pt_model = pt_model.eval() + + copy_text_model_and_projection(hf_model, pt_model) + copy_vison_model_and_projection(hf_model, pt_model) + hf_model.logit_scale = pt_model.logit_scale + + # Use `eos_token` so the example is more meaningful + input_ids = torch.tensor( + [ + [config.text_config.bos_token_id] + + list(range(3, 77)) + + [config.text_config.eos_token_id] + + [config.text_config.pad_token_id] + ] + ) + pixel_values = torch.randn(1, 3, 224, 224) + + hf_outputs = hf_model(input_ids=input_ids, pixel_values=pixel_values, return_dict=True) + hf_logits_per_image = hf_outputs.logits_per_image + hf_logits_per_text = hf_outputs.logits_per_text + pt_logits_per_image, pt_logits_per_text = pt_model(pixel_values, input_ids) + + assert torch.allclose(hf_logits_per_image, pt_logits_per_image, atol=1e-3) + assert torch.allclose(hf_logits_per_text, pt_logits_per_text, atol=1e-3) + + hf_model.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to OpenAI checkpoint") + parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") + args = parser.parse_args() + + convert_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path) diff --git a/old/.ipynb_checkpoints/dataset_10k_train-checkpoint.jsonl b/old/.ipynb_checkpoints/dataset_10k_train-checkpoint.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..50fb0eef6b9a3f56019ea6ee4f036a692346c409 --- /dev/null +++ b/old/.ipynb_checkpoints/dataset_10k_train-checkpoint.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e0f6360a5bc18603afd8cd64d3d7b6e9b5b55b204a53031ce3570be5f01aa05b +size 16739995 diff --git a/old/dataset_10k_train.jsonl b/old/dataset_10k_train.jsonl new file mode 
100644 index 0000000000000000000000000000000000000000..50fb0eef6b9a3f56019ea6ee4f036a692346c409 --- /dev/null +++ b/old/dataset_10k_train.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e0f6360a5bc18603afd8cd64d3d7b6e9b5b55b204a53031ce3570be5f01aa05b +size 16739995 diff --git a/seamless_interaction/assets/banner.gif b/seamless_interaction/assets/banner.gif new file mode 100644 index 0000000000000000000000000000000000000000..f02e52988d4bebe998cdba2b8d18c0e70811ef77 --- /dev/null +++ b/seamless_interaction/assets/banner.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6b47141b5f3018e8387671dfe858090c810438902c6e6d72a7022c01e262b08c +size 36172171 diff --git a/swift/llm/template/__pycache__/vision_utils.cpython-310.pyc b/swift/llm/template/__pycache__/vision_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9376ec01ac995df4616cca12c132d04e8743b43a Binary files /dev/null and b/swift/llm/template/__pycache__/vision_utils.cpython-310.pyc differ diff --git a/swift/llm/template/template/__init__.py b/swift/llm/template/template/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fce57ff22b23d164c2b552049700a08ea9fa221a --- /dev/null +++ b/swift/llm/template/template/__init__.py @@ -0,0 +1,2 @@ +from . import (deepseek, emu3, gemma, glm, idefics3, internlm, internvl, llama, llava, llm, megrez, microsoft, minicpm, + minimax, mistral, molmo, moonshot, mplug, openbuddy, pixtral, qwen, stepfun, valley, yi) diff --git a/swift/llm/template/template/__pycache__/__init__.cpython-310.pyc b/swift/llm/template/template/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf97c89b4388d6987fc5b1e43008ccf21135d871 Binary files /dev/null and b/swift/llm/template/template/__pycache__/__init__.cpython-310.pyc differ diff --git a/swift/llm/template/template/__pycache__/deepseek.cpython-310.pyc b/swift/llm/template/template/__pycache__/deepseek.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35ea125bcadb2e6d15d4dc358cbc3088c6107683 Binary files /dev/null and b/swift/llm/template/template/__pycache__/deepseek.cpython-310.pyc differ diff --git a/swift/llm/template/template/__pycache__/emu3.cpython-310.pyc b/swift/llm/template/template/__pycache__/emu3.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..781e348228008d394eb037e150bf95ae1fa1a393 Binary files /dev/null and b/swift/llm/template/template/__pycache__/emu3.cpython-310.pyc differ diff --git a/swift/llm/template/template/__pycache__/gemma.cpython-310.pyc b/swift/llm/template/template/__pycache__/gemma.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08b36b7a4246347c6325160eb914e03b80b474d4 Binary files /dev/null and b/swift/llm/template/template/__pycache__/gemma.cpython-310.pyc differ diff --git a/swift/llm/template/template/__pycache__/glm.cpython-310.pyc b/swift/llm/template/template/__pycache__/glm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cf392e2e0fe26e667b19228465c63ca4c3254b7 Binary files /dev/null and b/swift/llm/template/template/__pycache__/glm.cpython-310.pyc differ diff --git a/swift/llm/template/template/__pycache__/idefics3.cpython-310.pyc b/swift/llm/template/template/__pycache__/idefics3.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db1455e999a7738ed12feee6c6e8e544bfb0c1ac Binary 
files /dev/null and b/swift/llm/template/template/__pycache__/idefics3.cpython-310.pyc differ diff --git a/swift/llm/template/template/__pycache__/internlm.cpython-310.pyc b/swift/llm/template/template/__pycache__/internlm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c2e3fe7cb19764ad3c7b7bb9454d5e83a2c5cd6 Binary files /dev/null and b/swift/llm/template/template/__pycache__/internlm.cpython-310.pyc differ diff --git a/swift/llm/template/template/__pycache__/internvl.cpython-310.pyc b/swift/llm/template/template/__pycache__/internvl.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e50fb1c25d12f4839ad413077c4768d813435962 Binary files /dev/null and b/swift/llm/template/template/__pycache__/internvl.cpython-310.pyc differ diff --git a/swift/llm/template/template/__pycache__/llama.cpython-310.pyc b/swift/llm/template/template/__pycache__/llama.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de4673306975c35281674f5649222baeee8047e3 Binary files /dev/null and b/swift/llm/template/template/__pycache__/llama.cpython-310.pyc differ diff --git a/swift/llm/template/template/__pycache__/llava.cpython-310.pyc b/swift/llm/template/template/__pycache__/llava.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8336d8ca16feb550b71265b203d41d3c6f3747e2 Binary files /dev/null and b/swift/llm/template/template/__pycache__/llava.cpython-310.pyc differ diff --git a/swift/llm/template/template/__pycache__/llm.cpython-310.pyc b/swift/llm/template/template/__pycache__/llm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9796394798d0be20ea1dc61d51b9fbf412673827 Binary files /dev/null and b/swift/llm/template/template/__pycache__/llm.cpython-310.pyc differ diff --git a/swift/llm/template/template/__pycache__/megrez.cpython-310.pyc b/swift/llm/template/template/__pycache__/megrez.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d78cb5fd79967f7b85eefdddd9d7c9eb6a60940d Binary files /dev/null and b/swift/llm/template/template/__pycache__/megrez.cpython-310.pyc differ diff --git a/swift/llm/template/template/__pycache__/microsoft.cpython-310.pyc b/swift/llm/template/template/__pycache__/microsoft.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fd94dd2228ff333c0115bb12b5764ecc949c7cc Binary files /dev/null and b/swift/llm/template/template/__pycache__/microsoft.cpython-310.pyc differ diff --git a/swift/llm/template/template/__pycache__/minicpm.cpython-310.pyc b/swift/llm/template/template/__pycache__/minicpm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38fa35ae72a511c40a8e48894d525adbe9e5521b Binary files /dev/null and b/swift/llm/template/template/__pycache__/minicpm.cpython-310.pyc differ diff --git a/swift/llm/template/template/__pycache__/minimax.cpython-310.pyc b/swift/llm/template/template/__pycache__/minimax.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2fa21bc72a9a7dbb0e91c0989c385675273f1ff Binary files /dev/null and b/swift/llm/template/template/__pycache__/minimax.cpython-310.pyc differ diff --git a/swift/llm/template/template/__pycache__/mistral.cpython-310.pyc b/swift/llm/template/template/__pycache__/mistral.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b335d5ac9500378b50daa43e565577ecc2a7792 Binary files /dev/null and 
b/swift/llm/template/template/__pycache__/mistral.cpython-310.pyc differ diff --git a/swift/llm/template/template/__pycache__/molmo.cpython-310.pyc b/swift/llm/template/template/__pycache__/molmo.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c00ad0b7b7703f9dc92143501fef3c688a0dd51 Binary files /dev/null and b/swift/llm/template/template/__pycache__/molmo.cpython-310.pyc differ diff --git a/swift/llm/template/template/__pycache__/moonshot.cpython-310.pyc b/swift/llm/template/template/__pycache__/moonshot.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..754acaf02e244dacc0b7eab63d1856056421dece Binary files /dev/null and b/swift/llm/template/template/__pycache__/moonshot.cpython-310.pyc differ diff --git a/swift/llm/template/template/__pycache__/mplug.cpython-310.pyc b/swift/llm/template/template/__pycache__/mplug.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55b96e771cdcfd32779dcfa8d3e820280a934cdd Binary files /dev/null and b/swift/llm/template/template/__pycache__/mplug.cpython-310.pyc differ diff --git a/swift/llm/template/template/__pycache__/openbuddy.cpython-310.pyc b/swift/llm/template/template/__pycache__/openbuddy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..059d63247173d21db5b4079a11480a03508ab08e Binary files /dev/null and b/swift/llm/template/template/__pycache__/openbuddy.cpython-310.pyc differ diff --git a/swift/llm/template/template/__pycache__/pixtral.cpython-310.pyc b/swift/llm/template/template/__pycache__/pixtral.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efdada53ffa06e8df1ac13a3b8e08aed8a5dcb6e Binary files /dev/null and b/swift/llm/template/template/__pycache__/pixtral.cpython-310.pyc differ diff --git a/swift/llm/template/template/__pycache__/qwen.cpython-310.pyc b/swift/llm/template/template/__pycache__/qwen.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89b986a22cc81188ac6e9cce178a8156b95b8fa3 Binary files /dev/null and b/swift/llm/template/template/__pycache__/qwen.cpython-310.pyc differ diff --git a/swift/llm/template/template/__pycache__/stepfun.cpython-310.pyc b/swift/llm/template/template/__pycache__/stepfun.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfebc77af8847cfd231ec8203f9ea2be473b7371 Binary files /dev/null and b/swift/llm/template/template/__pycache__/stepfun.cpython-310.pyc differ diff --git a/swift/llm/template/template/__pycache__/utils.cpython-310.pyc b/swift/llm/template/template/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7511de12479cd316f4a718fb5ab1f904b2a363e8 Binary files /dev/null and b/swift/llm/template/template/__pycache__/utils.cpython-310.pyc differ diff --git a/swift/llm/template/template/__pycache__/valley.cpython-310.pyc b/swift/llm/template/template/__pycache__/valley.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f59dc987d14735777e7b7f7e5b0d896b887c4a28 Binary files /dev/null and b/swift/llm/template/template/__pycache__/valley.cpython-310.pyc differ diff --git a/swift/llm/template/template/__pycache__/yi.cpython-310.pyc b/swift/llm/template/template/__pycache__/yi.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ed83c9dc462bb034a00e3a77a9159794beef20a Binary files /dev/null and 
b/swift/llm/template/template/__pycache__/yi.cpython-310.pyc differ diff --git a/swift/llm/template/template/deepseek.py b/swift/llm/template/template/deepseek.py new file mode 100644 index 0000000000000000000000000000000000000000..cda07ecf93476c9a7edd610873740d48ee7e7352 --- /dev/null +++ b/swift/llm/template/template/deepseek.py @@ -0,0 +1,315 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +import numpy as np +import torch +import torch.nn as nn +from PIL import Image + +from swift.utils import get_env_args +from ..base import Template +from ..constant import LLMTemplateType, MLLMTemplateType +from ..register import TemplateMeta, register_template +from ..template_inputs import StdTemplateInputs +from ..utils import Prompt, findall + + +@dataclass +class DeepseekTemplateMeta(TemplateMeta): + prefix: Prompt = field(default_factory=lambda: [['bos_token_id']]) + prompt: Prompt = field(default_factory=lambda: ['User: {{QUERY}}\n\nAssistant:']) + chat_sep: Optional[Prompt] = field(default_factory=lambda: [['eos_token_id']]) + suffix: Prompt = field(default_factory=lambda: [['eos_token_id']]) + system_prefix: Optional[Prompt] = field(default_factory=lambda: [['bos_token_id'], '{{SYSTEM}}\n\n']) + + +register_template(DeepseekTemplateMeta(LLMTemplateType.deepseek, )) + +register_template( + TemplateMeta( + LLMTemplateType.deepseek_coder, + prefix=['{{SYSTEM}}'], + prompt=['### Instruction:\n{{QUERY}}\n### Response:\n'], + chat_sep=['\n<|EOT|>\n'], + suffix=['\n<|EOT|>'], + stop_words=['<|EOT|>'], + default_system=('You are an AI programming assistant, utilizing the Deepseek Coder model, ' + 'developed by Deepseek Company, and you only answer questions related to computer science. 
' + 'For politically sensitive questions, security and privacy issues, ' + 'and other non-computer science questions, you will refuse to answer\n'))) + + +class DeepseekVLTemplate(Template): + image_placeholder = ['<image_placeholder>'] + skip_prompt = False + use_model = True + placeholder_tokens = ['<image_placeholder>'] + + image_token_num_per_image: int = 576 + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + is_janus = getattr(self, 'is_janus', False) + + encoded = super()._encode(inputs) + images = inputs.images + processor = self.processor + input_ids, labels = encoded['input_ids'], encoded['labels'] + + if not inputs.generate_mode: # understanding task + idx_list = findall(input_ids, processor.image_id) # '<image_placeholder>' + new_input_ids, new_labels = [], [] + lo = 0 + for hi in idx_list: + new_input_ids += input_ids[lo:hi] + if labels is not None: + new_labels += labels[lo:hi] + image_tokens = [processor.image_id] * processor.num_image_tokens + if is_janus: + image_tokens = [processor.image_start_id] + image_tokens + [processor.image_end_id] + new_input_ids += image_tokens + new_labels += [-100] * len(image_tokens) + lo = hi + 1 + new_input_ids += input_ids[lo:] + if labels is not None: + new_labels += labels[lo:] + else: + new_labels = None + if is_janus: + from janus.models.processing_vlm import VLChatProcessorOutput + else: + from deepseek_vl.models.processing_vlm import VLChatProcessorOutput + + images_outputs = processor.image_processor(images, return_tensors='pt') + output = VLChatProcessorOutput( + sft_format=None, + input_ids=torch.tensor(new_input_ids), + pixel_values=images_outputs.pixel_values, + num_image_tokens=torch.tensor([processor.num_image_tokens] * len(idx_list))) + encoded = {'output': output, 'input_ids': new_input_ids, 'labels': new_labels} + return encoded + + else: # image generation task + if self.is_training: + raise NotImplementedError('Only support the inference of generation of Janus series models.') + sft_format = self.tokenizer.decode(input_ids) + prompt = sft_format + processor.image_start_tag + input_ids = processor.tokenizer.encode(prompt) + input_ids = torch.LongTensor(input_ids) + + encoded = {'input_ids': input_ids, 'labels': labels, 'generate_mode': inputs.generate_mode} + return encoded + + def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]: + if not inputs.get('generate_mode'): + inputs['pixel_values'] = inputs['pixel_values'].to(dtype=self.model_info.torch_dtype) + inputs_embeds = model.prepare_inputs_embeds(**inputs) + return {'inputs_embeds': inputs_embeds} + else: + return inputs + + def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: + gene_img_list = [b.get('generate_mode') for b in batch] + if all(gene_img_list): + generate_mode = True + elif not any(gene_img_list): + generate_mode = False + else: + raise NotImplementedError('Do not support understanding and image generation tasks in one batch.') + + if not generate_mode: + output = self.fetch_inputs(batch, ['output'])['output'] + batched_output = dict(self.processor.batchify(output)) + res = super()._data_collator(batch, padding_to=padding_to) + return {**batched_output, **res} + else: + res = super()._data_collator(batch, padding_to=padding_to) + res['generate_mode'] = generate_mode + return res + + def generate(self, model, *args, **kwargs): + if not kwargs.get('generate_mode'): + return super().generate(model, *args, **kwargs) + + else: + # how many images to generate for each prompt; this is named parallel_size in the
author's code + parallel_size = kwargs['generation_config'].num_return_sequences + temperature = kwargs['generation_config'].temperature + cfg_weight = get_env_args('cfg_weight', float, 5.0) + + input_ids = kwargs['input_ids'] # [bsz, max_input_token_num] + bsz, max_input_token_num = input_ids.shape + tokens = torch.zeros((bsz, parallel_size * 2, max_input_token_num), + dtype=torch.int).cuda() # [bsz, parallel_size*2, max_input_token_num] + for i in range(parallel_size * 2): + tokens[:, i, :] = input_ids + if i % 2 != 0: + tokens[:, i, 1:-1] = self.processor.pad_id + + inputs_embeds = model.language_model.get_input_embeddings()( + tokens) # [bsz, parallel_size*2, max_input_token_num, 2048] + + generated_tokens = torch.zeros( + (bsz, parallel_size, self.image_token_num_per_image), + dtype=torch.int).cuda() # [bsz, 16, image_token_num_per_image] placeholder for the generated tokens + + # set the first two dimensions into one dimension for batch size + inputs_embeds = inputs_embeds.reshape(bsz * parallel_size * 2, max_input_token_num, -1) + generated_tokens = generated_tokens.reshape(bsz * parallel_size, self.image_token_num_per_image) + + for i in range(self.image_token_num_per_image): # generate the tokens of image in a auto-regression way + outputs = model.language_model.model( + inputs_embeds=inputs_embeds, + use_cache=True, + past_key_values=outputs.past_key_values if i != 0 else None) + hidden_states = outputs.last_hidden_state + + logits = self.model.gen_head(hidden_states[:, -1, :]) + logit_cond = logits[0::2, :] + logit_uncond = logits[1::2, :] + + logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond) + probs = torch.softmax(logits / temperature, dim=-1) + + next_token = torch.multinomial(probs, num_samples=1) + generated_tokens[:, i] = next_token.squeeze(dim=-1) # [parallel_size, self.image_token_num_per_image] + + next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1) + img_embeds = model.prepare_gen_img_embeds(next_token) # [parallel_size * 2, 2048] + inputs_embeds = img_embeds.unsqueeze(dim=1) # [parallel_size * 2, 1, 2048] + + # no need to reset the original first two dimensions, waiting for the update of the upper layer + # inputs_embeds = inputs_embeds.reshape(bsz, parallel_size*2, -1) + # generated_tokens = generated_tokens.reshape(bsz, parallel_size, self.image_token_num_per_image) + + return {'sequences': generated_tokens} + + def decode(self, generate_ids: List[int], **kwargs) -> Any: + if 'template_inputs' not in kwargs or not kwargs['template_inputs'].generate_mode: + return super().decode(generate_ids, **kwargs) + else: + img_size = get_env_args('img_size', int, 384) + patch_size = 16 + + num_to_decode = 1 # for now, generate_ids is a 1D list + + generate_ids = torch.tensor(generate_ids).unsqueeze(0) # [num_to_decode=1, self.image_token_num_per_image] + + dec = self.model.gen_vision_model.decode_code( + generate_ids.to(dtype=torch.int), + shape=[num_to_decode, 8, img_size // patch_size, img_size // patch_size]) + dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1) # [num_to_decode, H, W, ch=3] + + dec = np.clip((dec + 1) / 2 * 255, 0, 255) + + visual_img = np.zeros((num_to_decode, img_size, img_size, 3), dtype=np.uint8) + visual_img[:, :, :] = dec + + img_list = [] + for i in range(num_to_decode): + cur_img = Image.fromarray(visual_img[i]) + img_list.append({'type': 'image', 'image': cur_img}) + return img_list + + +@dataclass +class DeepseekVLTemplateMeta(DeepseekTemplateMeta): + default_system: 
Optional[str] = ('You are a helpful language and vision assistant. ' + 'You are able to understand the visual content that the user provides, ' + 'and assist the user with a variety of tasks using natural language.') + + +register_template(DeepseekVLTemplateMeta( + MLLMTemplateType.deepseek_vl, + template_cls=DeepseekVLTemplate, +)) + + +class DeepseekJanus(DeepseekVLTemplate): + is_janus = True + image_placeholder = ['<image_placeholder>\n'] + + +register_template(DeepseekVLTemplateMeta(MLLMTemplateType.deepseek_janus, template_cls=DeepseekJanus)) + + +@dataclass +class DeepseekV2_5TemplateMeta(TemplateMeta): + prefix: Prompt = field(default_factory=lambda: ['<|begin▁of▁sentence|>{{SYSTEM}}']) + prompt: Prompt = field(default_factory=lambda: ['<|User|>{{QUERY}}<|Assistant|>']) + chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|end▁of▁sentence|>']) + suffix: Prompt = field(default_factory=lambda: ['<|end▁of▁sentence|>']) + + +register_template(DeepseekV2_5TemplateMeta(LLMTemplateType.deepseek_v2_5)) + + +class DeepseekR1Template(Template): + + def _swift_encode(self, inputs: StdTemplateInputs): + if not self.is_training: + for message in inputs.messages: + if message['role'] == 'assistant' and isinstance(message['content'], str): + message['content'] = message['content'].split('</think>')[-1] + return super()._swift_encode(inputs) + + +register_template( + DeepseekV2_5TemplateMeta(LLMTemplateType.deepseek_r1, template_cls=DeepseekR1Template, response_prefix='<think>\n')) + + +class DeepseekVL2Template(DeepseekVLTemplate): + image_placeholder = ['<image>\n'] + placeholder_tokens = ['<image>'] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + from deepseek_vl2.models.processing_deepseek_vl_v2 import VLChatProcessorOutput + encoded = Template._encode(self, inputs) + images = inputs.images + processor = self.processor + input_ids, labels = encoded['input_ids'], encoded['labels'] + images_seq_mask = [False] * len(input_ids) + idx_list = findall(input_ids, processor.image_token_id) # '<image>' + _, images_list, _, images_spatial_crop, num_image_tokens = processor.tokenize_with_images( + '<image>' * len(images), images, cropping=len(images) <= 2) + new_num_tokens = 0 + for idx, n_image_tokens in zip(idx_list, num_image_tokens): + image_tokens = [processor.image_token_id] * n_image_tokens + input_ids = input_ids[:idx] + image_tokens + input_ids[idx + 1:] + if labels is not None: + labels = labels[:idx] + [-100] * n_image_tokens + labels[idx + 1:] + images_seq_mask = images_seq_mask[:idx] + [True] * n_image_tokens + images_seq_mask[idx + 1:] + new_num_tokens += n_image_tokens - 1 + + output = VLChatProcessorOutput( + sft_format=None, + input_ids=torch.tensor(input_ids), + target_ids=torch.tensor(input_ids), + images=torch.stack(images_list) if images_list else torch.zeros((0, 3, 384, 384)), + images_seq_mask=torch.tensor(images_seq_mask), + images_spatial_crop=torch.tensor(images_spatial_crop), + num_image_tokens=num_image_tokens) + output.images = output.images.to(dtype=self.model_info.torch_dtype) + encoded = {'output': output, 'input_ids': input_ids, 'labels': labels} + return encoded + + def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]: + inputs['images_seq_mask'] = inputs['images_seq_mask'].to(torch.bool) + inputs['images_spatial_crop'] = inputs['images_spatial_crop'].to(torch.long) + inputs_embeds = model.prepare_inputs_embeds(**inputs) + return {'inputs_embeds': inputs_embeds} + + +register_template( + DeepseekV2_5TemplateMeta( + MLLMTemplateType.deepseek_vl2, + prompt=['<|User|>:
{{QUERY}}\n\n<|Assistant|>:'], + template_cls=DeepseekVL2Template, + )) + +register_template( + DeepseekVLTemplateMeta( + MLLMTemplateType.deepseek_janus_pro, + prompt=['<|User|>: {{QUERY}}\n\n<|Assistant|>:'], + template_cls=DeepseekJanus)) diff --git a/swift/llm/template/template/emu3.py b/swift/llm/template/template/emu3.py new file mode 100644 index 0000000000000000000000000000000000000000..47cf7d421c3aef61027caa07913032475a44bed2 --- /dev/null +++ b/swift/llm/template/template/emu3.py @@ -0,0 +1,191 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import random +from typing import Any, Dict, List, Optional + +import torch +from PIL import Image + +from swift.utils import get_device +from ..base import Template +from ..constant import MLLMTemplateType +from ..register import register_template +from ..template_inputs import StdTemplateInputs +from ..template_meta import TemplateMeta +from ..utils import findall +from .utils import DEFAULT_SYSTEM, EmptyTemplateMeta + + +class Emu3GenTemplate(Template): + + NULL_PROMPT_PROB = 0.1 + COOKBOOK_SIZE = 32768 + CFG_SCALE = os.environ.get('CFG_SCALE', 3.0) + GENERATION_RATIO = os.environ.get('GENERATION_RATIO', '1:1') + NEGATIVE_PROMPT = os.environ.get( + 'NEGATIVE_PROMPT', + 'lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, ' + 'worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry.') + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.bov = self.processor.tokenizer.encode(self.processor.visual_template[0].format(token_id=0))[0] + self.eov = self.processor.tokenizer.encode(self.processor.visual_template[0].format(token_id=self.COOKBOOK_SIZE + - 1))[0] + self.h, self.w = self.processor.calculate_generate_size(self.GENERATION_RATIO, self.processor.image_area, + self.processor.vision_tokenizer.spatial_scale_factor) + self.skip_prompt = False + self.apply_loss_on_only_vision = True + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + if self.is_training: + p_prob = random.random() + if p_prob < self.NULL_PROMPT_PROB: + prompt = '' + else: + prompt = inputs.to_history()['response'] + image = self.smart_resize(inputs.images[0].convert('RGB')) + with torch.no_grad(): + image = self.processor.image_processor( + image, return_tensors='pt')['pixel_values'].to(device=self.processor.vision_tokenizer.device) + image_token_ids = self.processor.vision_tokenizer.encode(image).squeeze(0) + encoded = self._process_prompt_train(prompt, image_token_ids) + else: + prompt = inputs.to_history()['query'] + encoded = self._process_prompt_test(prompt) + encoded = {key: encoded[key][0] for key in encoded.keys()} # [1, L] -> [L] + + return encoded + + def _process_prompt_train(self, raw_prompt, image_token_ids): + image_prompt = self.format_image_prompt(image_token_ids) + prompt = self.tokenizer.bos_token + raw_prompt + image_prompt + sample = self.tokenizer(prompt, padding='max_length', return_token_type_ids=False) + labels = torch.tensor(sample['input_ids']) + if self.apply_loss_on_only_vision: + labels = torch.where(torch.logical_and(labels >= self.bov, labels <= self.eov), labels, -100) + sample['labels'] = labels.tolist() + return sample + + def _process_prompt_test(self, raw_prompt): + # for supporting multi inputs, use list instead of single string + if isinstance(raw_prompt, str): + raw_prompt = [raw_prompt] + prompt_list = [] + size_list = [] + for text_prompt in raw_prompt: + prompt = 
self.processor.tokenizer.bos_token + image_prompt = ( + self.processor.tokenizer.boi_token + self.processor.prefix_template.format(H=self.h, W=self.w) + + self.processor.tokenizer.img_token) + prompt += (text_prompt + image_prompt) + prompt_list.append(prompt) + size_list.append([self.h, self.w]) + prompt_list = self.tokenizer(prompt_list, padding='longest', return_token_type_ids=False) + return prompt_list + + def prepare_for_output(self, output: str) -> str: + return output + + def prepare_generate_kwargs(self, generate_kwargs: Dict[str, Any], *, model=None) -> Dict[str, Any]: + from transformers import UnbatchedClassifierFreeGuidanceLogitsProcessor + from transformers import PrefixConstrainedLogitsProcessor + from transformers import LogitsProcessorList + + negative_prompt = self.NEGATIVE_PROMPT + neg_inputs = self._process_prompt_test(negative_prompt) + neg_inputs = {key: torch.tensor(val) for key, val in neg_inputs.items()} + batch_size = generate_kwargs['input_ids'].shape[0] + h = torch.tensor([self.h] * batch_size) + w = torch.tensor([self.w] * batch_size) + + constrained_fn = self.processor.build_prefix_constrained_fn(h, w) + logits_processor = LogitsProcessorList([ + UnbatchedClassifierFreeGuidanceLogitsProcessor( + self.CFG_SCALE, + model, + unconditional_ids=neg_inputs['input_ids'].to(get_device()), + ), + PrefixConstrainedLogitsProcessor( + constrained_fn, + num_beams=1, + ), + ]) + res = super().prepare_generate_kwargs(generate_kwargs, model=model) + res['logits_processor'] = logits_processor + return res + + def decode(self, generate_ids: List[int], **kwargs) -> Any: + mm_list = self.processor.decode(generate_ids) + for im in mm_list: + if not isinstance(im, Image.Image): + continue + return [{'type': 'image', 'image': im}] + + def to_imgstr(self, image_tokens): + image_token_str = [[self.processor.visual_template[0].format(token_id=token_id) for token_id in token_row] + for token_row in image_tokens] + image_row_str = [''.join(token_row) for token_row in image_token_str] + imgstr = self.tokenizer.eol_token.join(image_row_str) + return imgstr + + def format_image_prompt(self, image_tokens): + h, w = image_tokens.shape + imgstr = self.to_imgstr(image_tokens) + image_prompt = ( + self.tokenizer.boi_token + f'{h}*{w}' + self.tokenizer.img_token + imgstr + self.tokenizer.eol_token + + self.tokenizer.eof_token + self.tokenizer.eoi_token) + return image_prompt + + def smart_resize(self, image): + w, h = image.size + current_area = h * w + target_ratio = (self.processor.image_area / current_area)**0.5 + th = int(round(h * target_ratio)) + tw = int(round(w * target_ratio)) + image = image.resize((tw, th)) + return image + + +register_template(EmptyTemplateMeta( + MLLMTemplateType.emu3_gen, + template_cls=Emu3GenTemplate, +)) + + +class Emu3ChatTemplate(Template): + system = 'You are a helpful assistant.' 
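+    # '<|image token|>' below is the user-facing image placeholder; _encode replaces each occurrence with the full tokenized image prompt (boi token, H*W prefix, img token, the VQ code string, then eol/eof/eoi tokens).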
+ image_placeholder = ['<|image token|>'] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + # image + images = inputs.images + input_ids = encoded['input_ids'] + labels = encoded['labels'] + image_tokens = self.processor.tokenize_image(images) + image_prompts = [] + idx_list = findall(input_ids, self.tokenizer.encode(self.image_placeholder)) + # Create image prompts + for i in range(len(images)): + h, w = image_tokens[i].shape + imgstr = self.processor.to_imgstr(image_tokens[i]) + image_prompt = ( + self.tokenizer.boi_token + self.processor.prefix_template.format(H=h, W=w) + self.tokenizer.img_token + + imgstr + self.tokenizer.eol_token + self.tokenizer.eof_token + self.tokenizer.eoi_token) + image_prompts.append(self.tokenizer.encode(image_prompt)) + + # Insert image tokens into input_ids + input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, lambda i: image_prompts[i]) + return {'input_ids': input_ids, 'labels': labels} + + +register_template( + TemplateMeta( + MLLMTemplateType.emu3_chat, + prefix=[['bos_token_id'], '{{SYSTEM}}'], + prompt=[' User: {{QUERY}}. Assistant:'], + chat_sep=[['eos_token_id']], + suffix=[['eos_token_id']], + default_system=DEFAULT_SYSTEM, + template_cls=Emu3ChatTemplate)) diff --git a/swift/llm/template/template/glm.py b/swift/llm/template/template/glm.py new file mode 100644 index 0000000000000000000000000000000000000000..9feae85df073ae9feb3023d0097cc0848e7dc211 --- /dev/null +++ b/swift/llm/template/template/glm.py @@ -0,0 +1,293 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from dataclasses import dataclass, field +from typing import Any, Dict, List, Literal, Optional + +import torch + +from ..base import Template +from ..constant import LLMTemplateType, MLLMTemplateType +from ..register import TemplateMeta, register_template +from ..template_inputs import StdTemplateInputs +from ..utils import Context, Prompt, Word, findall +from ..vision_utils import load_batch, load_video_cogvlm2 + + +@dataclass +class GLMTemplateMeta(TemplateMeta): + auto_add_bos: bool = True + + +class GLM4Template(Template): + + def _swift_encode(self, inputs: StdTemplateInputs): + res_context_list, loss_scale_list, answer_len = super()._swift_encode(inputs) + for i, res_context in enumerate(res_context_list): + # The last round or is tool_call. 
+ if isinstance(res_context, str) and res_context.endswith('<|assistant|>\n') and ( + i + 1 >= len(res_context_list) or '<|observation|>' in res_context_list[i + 1]): + res_context_list[i] = res_context_list[i][:-len('\n')] + return res_context_list, loss_scale_list, answer_len + + def decode(self, *args, **kwargs): + response = super().decode(*args, **kwargs) + return response.lstrip('\n') + + +class GLM4_0414Template(GLM4Template): + + def _swift_encode(self, inputs: StdTemplateInputs): + if not self.is_training: + for message in inputs.messages: + if message['role'] == 'assistant' and isinstance(message['content'], str): + message['content'] = message['content'].split('')[-1].strip() + return super()._swift_encode(inputs) + + +register_template( + GLMTemplateMeta( + LLMTemplateType.chatglm2, + prefix=['{{SYSTEM}}'], + prompt=['[Round {{ROUND1}}]\n\n问:{{QUERY}}\n\n答:'], + chat_sep=['\n\n'])) + + +@dataclass +class GLM4TemplateMeta(GLMTemplateMeta): + prefix: Prompt = field(default_factory=list) + prompt: Prompt = field(default_factory=lambda: ['<|user|>\n{{QUERY}}<|assistant|>\n']) + chat_sep: Optional[Prompt] = field(default_factory=list) + suffix: Prompt = field(default_factory=lambda: ['<|user|>']) + system_prefix: Optional[Prompt] = field(default_factory=lambda: ['<|system|>\n{{SYSTEM}}']) + + agent_template: str = 'glm4' + stop_words: List[Word] = field(default_factory=lambda: ['<|endoftext|>', '<|user|>', '<|observation|>']) + + +@dataclass +class GLM4_0414TemplateMeta(GLM4TemplateMeta): + prefix: Prompt = field(default_factory=lambda: ['[gMASK]']) + system_prefix: Optional[Prompt] = field(default_factory=lambda: ['[gMASK]<|system|>\n{{SYSTEM}}']) + agent_template: str = 'glm4_0414' + + +class GLM4VTemplate(Template): + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + assert media_type == 'image' + return [[-100]] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + input_ids = encoded['input_ids'] + labels = encoded['labels'] + idx_list = findall(input_ids, -100) + if idx_list: + idx = idx_list[0] + image = inputs.images[0] + placeholder = '<|begin_of_image|><|endoftext|><|end_of_image|>' + placeholder_id = self.processor.encode(placeholder, add_special_tokens=False) + input_ids = (input_ids[:idx] + placeholder_id + input_ids[idx + 1:]) + if labels is not None: + labels = (labels[:idx] + [-100] * len(placeholder_id) + labels[idx + 1:]) + messages = inputs.messages + messages[0]['image'] = image + inputs2: Dict[str, Any] = self.processor.apply_chat_template(messages, return_dict=True) + encoded['images'] = inputs2['images'] + encoded['input_ids'] = input_ids + encoded['labels'] = labels + encoded['position_ids'] = list(range(len(input_ids))) + return encoded + + def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: + res = super()._data_collator(batch, padding_to=padding_to) + images = [b['images'] for b in batch if 'images' in b] + if images: + res['images'] = torch.concat(images) + return res + + +register_template(GLM4TemplateMeta(MLLMTemplateType.glm4v, template_cls=GLM4VTemplate, suffix=['<|endoftext|>'])) + +register_template(GLM4TemplateMeta(LLMTemplateType.glm4, template_cls=GLM4Template)) + +register_template(GLM4_0414TemplateMeta(LLMTemplateType.glm4_0414, template_cls=GLM4_0414Template)) + +glm4z1rumination_system = ( + '你是一个专业的深度研究助手,通过提供的工具与模拟浏览器交互,来帮助用户完成深度信息调研和报告撰写任务。' + '今年是 
2025 年。\n\n' + '<核心要求>\n' + '- 首先分解用户请求,得到包含多个子要求的列表\n' + '- 制定初始研究计划\n' + '- 进行多轮迭代搜索和页面浏览(at least 10 function calls):\n' + ' * 根据已获得的信息调整研究计划和关键词\n' + ' * 打开页面阅读,从发现的内容中识别新的关键概念/名词\n' + ' * 从搜索结果中提取新的关键词继续搜索\n' + ' * 访问并仔细阅读相关页面,识别新的关键概念/名词\n\n' + '<重要配置>\n' + '- 采用语言\n' + ' * 搜索关键词:英语\n' + ' * 思考:英语\n\n' + '<可调用的工具列表>\n\n' + '[{"name": "search", "description": "Execute a search query and return search results. ' + 'Use this function when you need to find information about a specific topic.", ' + '"parameters": {"type": "object", "properties": {"query": {"type": "string", ' + '"description": "Search query string, use English words unless it is a proper name in Chinese"}}, ' + '"required": ["query"], "additionalProperties": false}}, ' + '{"name": "click", "description": "Click a link in the search results and navigate to the corresponding page. ' + 'Use this function when you need to view detailed content of a specific search result.", ' + '"parameters": {"type": "object", "properties": {"link_id": {"type": "integer", ' + '"description": "The link ID to click (from the sequence number in search results)"}}, ' + '"required": ["link_id"], "additionalProperties": false}}, ' + '{"name": "open", "description": "Open a specific website. Get content from any website with its URL.", ' + '"parameters": {"type": "object", "properties": {"url": {"type": "string", ' + '"description": "The target website URL or domain"}}, "required": ["url"], "additionalProperties": false}}, ' + '{"name": "finish", "description": "Finish the task. ' + 'Use this function when you have found the information you need.", ' + '"parameters": {"type": "object", "properties": {}, "additionalProperties": false}}]') + +register_template( + GLM4_0414TemplateMeta( + LLMTemplateType.glm4_z1_rumination, template_cls=GLM4_0414Template, default_system=glm4z1rumination_system)) + +codegeex4_system = '你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。' + +register_template(GLM4TemplateMeta(LLMTemplateType.codegeex4, default_system=codegeex4_system)) + +register_template( + TemplateMeta( + LLMTemplateType.longwriter_llama, ['[INST]'], ['{{QUERY}}[/INST]'], ['[INST]'], ['<|end_of_text|>'], + system_prefix=['<>\n{{SYSTEM}}\n<>\n\n'])) + + +class CogTemplate(Template): + placeholder_tokens = ['<|reserved_special_token_0|>'] + + use_model = True + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + return [] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + model = self.model + image = inputs.images or [] + history_inputs = inputs.to_history() + inputs2 = model.build_conversation_input_ids( + self.processor, query=history_inputs['query'], history=history_inputs['history'], images=image) + image_token_len = inputs2['token_type_ids'].sum().item() + input_ids = encoded['input_ids'] + labels = encoded['labels'] + encoded['token_type_ids'] = [0] + [1] * image_token_len + [0] * len(input_ids[1:]) + encoded['input_ids'] = input_ids[:1] + [self.processor.pad_token_id] * image_token_len + input_ids[1:] + if labels is not None: + encoded['labels'] = labels[:1] + [-100] * image_token_len + labels[1:] + if len(image) > 0: + encoded['images'] = [[img.to(dtype=self.model_info.torch_dtype)] for img in inputs2['images']] + if 'cross_images' in inputs2: + # is cogagent + encoded['cross_images'] = [[cross_img.to(dtype=self.model_info.torch_dtype)] + for cross_img in inputs2['cross_images']] 
+ return encoded + + def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: + res = super()._data_collator(batch, padding_to=padding_to) + keys = ['images', 'cross_images'] + for key in keys: + if key in batch[0]: + res[key] = [b[key][0] for b in batch] + return res + + +register_template( + TemplateMeta( + MLLMTemplateType.cogagent_chat, + prefix=[''], + prompt=[' [INST] {{QUERY}} [/INST] '], + chat_sep=[], + suffix=[''], + template_cls=CogTemplate, + )) + +register_template( + TemplateMeta( + MLLMTemplateType.cogagent_vqa, + prefix=[''], + prompt=['Question: {{QUERY}} Answer:'], + chat_sep=None, + suffix=[''], + template_cls=CogTemplate)) + + +@dataclass +class CogVLMTemplateMeta(TemplateMeta): + prefix: Prompt = field(default_factory=lambda: [['bos_token_id']]) + prompt: Prompt = field(default_factory=lambda: ['Question: {{QUERY}} Answer:']) + chat_sep: Optional[Prompt] = field(default_factory=lambda: ['\n']) + + +register_template(CogVLMTemplateMeta(MLLMTemplateType.cogvlm, template_cls=CogTemplate)) + +register_template(CogVLMTemplateMeta(MLLMTemplateType.cogvlm2, template_cls=CogTemplate)) + + +class Cog2VideoTemplate(CogTemplate): + use_model = True + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + model = self.model + encoded = super(CogTemplate, self)._encode(inputs) + videos_path = inputs.videos or [] + video = load_batch(videos_path, load_video_cogvlm2) + history_inputs = inputs.to_history() + inputs2 = model.build_conversation_input_ids( + self.processor, + query=history_inputs['query'], + history=history_inputs['history'], + images=video, + template_version='chat') + video_token_len = inputs2['token_type_ids'].sum().item() + input_ids = encoded['input_ids'] + labels = encoded['labels'] + encoded['token_type_ids'] = [0] + [1] * video_token_len + [0] * len(input_ids[1:]) + encoded['input_ids'] = input_ids[:1] + [self.processor.pad_token_id] * video_token_len + input_ids[1:] + if labels is not None: + encoded['labels'] = labels[:1] + [-100] * video_token_len + labels[1:] + if len(video) > 0: + dtype = model.dtype + encoded['images'] = [[img.to(dtype=dtype)] for img in inputs2['images']] + return encoded + + +register_template(CogVLMTemplateMeta( + MLLMTemplateType.cogvlm2_video, + template_cls=Cog2VideoTemplate, +)) + + +class GLMEdgeVTemplate(Template): + placeholder_tokens = ['<|begin_of_image|>'] + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + assert media_type == 'image' + return ['<|begin_of_image|>' * 578] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + images = inputs.images + if images: + encoded['pixel_values'] = torch.tensor(self.processor(images).pixel_values) + return encoded + + +register_template( + GLM4TemplateMeta( + MLLMTemplateType.glm_edge_v, + prompt=['<|user|>\\n{{QUERY}}\\n<|assistant|>\\n'], + chat_sep=['\\n'], + system_prefix=['<|system|>\\n{{SYSTEM}}\\n'], + suffix=['<|endoftext|>'], + template_cls=GLMEdgeVTemplate, + )) diff --git a/swift/llm/template/template/idefics3.py b/swift/llm/template/template/idefics3.py new file mode 100644 index 0000000000000000000000000000000000000000..05497db676b20bbfabab81ab8acd8e6ae446b09b --- /dev/null +++ b/swift/llm/template/template/idefics3.py @@ -0,0 +1,37 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
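+# Idefics3 template: _encode re-runs the HF processor on the decoded prompt so that the '<image>' placeholders are aligned with the processor's expanded image-token sequence (align_image_inputs), and attaches pixel_values to the encoded inputs.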
+from typing import Any, Dict + +from ..base import Template +from ..constant import MLLMTemplateType +from ..register import TemplateMeta, register_template +from ..template_inputs import StdTemplateInputs +from ..utils import align_image_inputs + + +class Idefics3Template(Template): + placeholder_tokens = ['<image>'] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + images = inputs.images or [] + processor = self.processor + prompt = self.processor.decode(encoded['input_ids']) + if images: + image_inputs = processor(text=prompt, images=images, return_tensors='pt', add_special_tokens=False) + image_token = 128257 # <image> + encoded['input_ids'], encoded['labels'] = align_image_inputs(encoded['input_ids'], encoded['labels'], + image_inputs['input_ids'][0], image_token) + encoded['pixel_values'] = image_inputs['pixel_values'] + return encoded + + +register_template( + TemplateMeta( + MLLMTemplateType.idefics3, + prefix=['<|begin_of_text|>'], + prompt=['User:{{QUERY}}<end_of_utterance>\nAssistant:'], + chat_sep=['<end_of_utterance>\n'], + suffix=['<end_of_utterance>'], + system_prefix=['System:{{SYSTEM}}<end_of_utterance>\n'], + template_cls=Idefics3Template, + )) diff --git a/swift/llm/template/template/internlm.py b/swift/llm/template/template/internlm.py new file mode 100644 index 0000000000000000000000000000000000000000..fb4e9682fa7f0360fab62d2d918c2b7610f8faa1 --- /dev/null +++ b/swift/llm/template/template/internlm.py @@ -0,0 +1,195 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from dataclasses import dataclass, field +from typing import Any, Dict, List, Literal, Optional + +import torch +from PIL import Image +from transformers.dynamic_module_utils import get_class_from_dynamic_module + +from swift.utils import get_env_args +from ..base import Template +from ..constant import LLMTemplateType, MLLMTemplateType, RMTemplateType +from ..register import TemplateMeta, register_template +from ..template_inputs import StdTemplateInputs +from ..utils import Context, Prompt, Word +from ..vision_utils import load_file +from .utils import ChatmlTemplateMeta + +INTERNLM_SYSTEM = ( + 'You are an AI assistant whose name is InternLM (书生·浦语).\n' + '- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室).
' + 'It is designed to be helpful, honest, and harmless.\n' + '- InternLM (书生·浦语) can understand and communicate fluently in the language chosen ' + 'by the user such as English and 中文.') + +register_template( + TemplateMeta( + LLMTemplateType.internlm, + prefix=[''], + prompt=['<|User|>:{{QUERY}}\n<|Bot|>:'], + chat_sep=['\n'], + suffix=[''], + default_system=INTERNLM_SYSTEM, + system_prefix=['<|System|>:{{SYSTEM}}\n'])) + +register_template(ChatmlTemplateMeta(LLMTemplateType.internlm2, default_system=INTERNLM_SYSTEM)) + +register_template(ChatmlTemplateMeta(RMTemplateType.internlm2_reward, suffix=['<|im_end|>\n<|reward|>'])) + + +class InternLMXComposer2Template(Template): + image_placeholder = [''] + version = 'v2' + skip_prompt = False + use_model = True + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + if media_type == 'video': + inputs.images.insert(inputs.image_idx, inputs.videos[index]) + inputs.image_idx += 1 + return self.image_placeholder + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + model = self.model + encoded = super()._encode(inputs) + images = inputs.images or [] + + if self.version == 'v2.5': + hd_num = 24 + if len(images) > 1: + hd_num = 6 + hd_num = get_env_args('hd_num', int, hd_num) + images_origin = images + images = [] + for image in images_origin: + if isinstance(image, Image.Image): + Image_transform = get_class_from_dynamic_module('ixc_utils.Image_transform', model.model_dir) + images.append(Image_transform(image, hd_num=hd_num)) + else: + load_video = get_class_from_dynamic_module('ixc_utils.load_video', model.model_dir) + frame2img = get_class_from_dynamic_module('ixc_utils.frame2img', model.model_dir) + Video_transform = get_class_from_dynamic_module('ixc_utils.Video_transform', model.model_dir) + image = load_video(load_file(image)) + image = frame2img(image, model.font) + images.append(Video_transform(image, hd_num=hd_num)) + elif self.version == 'v2-4khd': + hd_num = 55 + hd_num = get_env_args('hd_num', int, hd_num) + HD_transform = get_class_from_dynamic_module('ixc_utils.HD_transform', model.model_dir) + images = [HD_transform(image, hd_num=hd_num) for image in images] + images = [model.vis_processor(image).to(model.dtype) for image in images] + encoded['images'] = images + return encoded + + def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]: + batch_size = len(inputs['input_ids']) + res = [] + im_mask = [] + length = inputs['length'] + for i in range(batch_size): + input_ids = inputs['input_ids'][i].tolist()[:length[i]] + input_ids.append(2) # add dummy + labels = inputs.get('labels') + if labels is not None: + labels = labels[i].tolist()[:length[i]] + labels.append(2) + else: + labels = [] + images = inputs['images'][i] + res_inputs_embeds = [] + res_labels = [] + wrap_im_mask = [] + pre_i, i, idx = 0, 0, 0 + device = model.device + internlm2_model = model.model + if not hasattr(internlm2_model, 'tok_embeddings'): + internlm2_model = internlm2_model.model + tok_embeddings = internlm2_model.tok_embeddings + if len(images) > 0: + images = torch.concat([model.img2emb(image[None])[0] for image in images], dim=0) + add_bos = False + while i < len(input_ids): + if input_ids[i] == 2: # replace_token + res_input_ids = torch.tensor(([1] if add_bos else []) + input_ids[pre_i:i], device=device) + if not add_bos and self.version != 'v2.5': + add_bos = True + res_inputs_embeds.append(tok_embeddings(res_input_ids[None])[0]) + wrap_im_mask 
+= [0] * len(res_input_ids) + res_labels += ([-100] if add_bos else []) + labels[pre_i:i] + if len(images) > 0 and idx < images.shape[0]: + res_inputs_embeds.append(images[idx].to(device)) + wrap_im_mask += [1] * images.shape[1] + res_labels += [-100] * images.shape[1] + idx += 1 + i += 1 + pre_i = i + continue + i += 1 + if len(labels) == 0: + res_labels = None + im_mask.append(torch.tensor(wrap_im_mask, dtype=torch.bool, device=device)) + res.append({'inputs_embeds': torch.concat(res_inputs_embeds, dim=0), 'labels': res_labels}) + res = Template._data_collator(self, res) + res['im_mask'] = self._pad_sequence(im_mask, 0) + return res + + def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: + res = super()._data_collator(batch, padding_to=padding_to) + res['length'] = [len(b['input_ids']) for b in batch] + res.update(self.fetch_inputs(batch, ['images'])) + return res + + +@dataclass +class Xcomposer2TemplateMeta(TemplateMeta): + prefix: Prompt = field(default_factory=lambda: ['']) + prompt: Prompt = field( + default_factory=lambda: ['[UNUSED_TOKEN_146]user\n{{QUERY}}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n']) + chat_sep: Optional[Prompt] = field(default_factory=lambda: ['[UNUSED_TOKEN_145]\n']) + suffix: Prompt = field(default_factory=lambda: ['[UNUSED_TOKEN_145]']) + system_prefix: Optional[Prompt] = field( + default_factory=lambda: ['[UNUSED_TOKEN_146]system\n{{SYSTEM}}[UNUSED_TOKEN_145]\n']) + stop_words: List[Word] = field(default_factory=lambda: ['<|im_end|>']) + + +register_template( + Xcomposer2TemplateMeta( + MLLMTemplateType.xcomposer2, + template_cls=InternLMXComposer2Template, + default_system=('You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).\n' + '- InternLM-XComposer (浦语·灵笔) is a conversational language model that is developed by ' + 'Shanghai AI Laboratory (上海人工智能实验室). ' + 'It is designed to be helpful, honest, and harmless.\n' + '- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the language chosen ' + 'by the user such as English and 中文.'), + )) + + +class InternLMXComposer2_5Template(InternLMXComposer2Template): + system = ('You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).\n' + '- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language model ' + 'that is developed by Shanghai AI Laboratory (上海人工智能实验室). ' + 'It is designed to be helpful, honest, and harmless.\n' + '- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the language chosen ' + 'by the user such as English and 中文.\n' + '- InternLM-XComposer (浦语·灵笔) is capable of comprehending and articulating responses effectively ' + 'based on the provided image.') + version = 'v2.5' + + +class InternLMXComposer2_4khdTemplate(InternLMXComposer2Template): + version = 'v2-4khd' + + +register_template( + Xcomposer2TemplateMeta( + MLLMTemplateType.xcomposer2_5, + template_cls=InternLMXComposer2_5Template, + default_system=InternLMXComposer2_5Template.system)) + +register_template( + Xcomposer2TemplateMeta( + MLLMTemplateType.xcomposer2_4khd, + template_cls=InternLMXComposer2_4khdTemplate, + default_system=InternLMXComposer2_5Template.system)) diff --git a/swift/llm/template/template/internvl.py b/swift/llm/template/template/internvl.py new file mode 100644 index 0000000000000000000000000000000000000000..0c9973ad7974e7c228b4554831e7e5bc3fbd1660 --- /dev/null +++ b/swift/llm/template/template/internvl.py @@ -0,0 +1,168 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
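+# InternVL templates: each image is wrapped as '<img>' ... '</img>' and expanded to num_image_token '<IMG_CONTEXT>' tokens per patch; _post_encode scatters the ViT features into the corresponding input-embedding positions.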
+from functools import partial +from typing import Any, Dict, List, Literal + +import torch +from torch import nn + +from swift.utils import get_env_args, is_deepspeed_enabled +from ..base import Template +from ..constant import MLLMTemplateType +from ..register import register_template +from ..template_inputs import StdTemplateInputs +from ..utils import Context, findall +from ..vision_utils import load_video_internvl, transform_image +from .microsoft import Phi3TemplateMeta +from .utils import ChatmlTemplateMeta + + +class InternvlTemplate(Template): + skip_prompt = False + num_image_token = 256 + placeholder_tokens = ['<IMG_CONTEXT>'] + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + if self.mode == 'vllm': + image_context = ['<image>\n'] + else: + image_context = ['<img>', [-100], '</img>\n'] + return image_context + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + input_ids = encoded['input_ids'] + idx_list = findall(input_ids, -100) + pixel_values = None + images = inputs.images + if images: + labels = encoded.get('labels') + input_size = get_env_args('input_size', int, 448) + max_num = get_env_args('max_num', int, 12) + pixel_values_images = [transform_image(image, input_size, max_num) for image in images] + pixel_values = torch.cat(pixel_values_images, dim=0).to(self.model_info.torch_dtype) + image_bs = pixel_values.shape[0] + + idx, idx2 = idx_list[0], idx_list[-1] # remove [-100, -100] + img_tokens: List[int] = self.processor.encode( + '<IMG_CONTEXT>', add_special_tokens=False) * self.num_image_token * image_bs + input_ids = input_ids[:idx] + img_tokens + input_ids[idx2 + 1:] + if labels is not None: + labels = labels[:idx] + [-100] * len(img_tokens) + labels[idx2 + 1:] + encoded['input_ids'] = input_ids + encoded['labels'] = labels + encoded['pixel_values'] = pixel_values + return encoded + + def compute_loss_context(self, model, inputs): + model_name = model.language_model.__class__.__name__.lower() + if self._packing and 'internlm2' in model_name: + position_ids = inputs['position_ids'] + modeling_module = model.language_model.model.layers[0].attention.__class__ + return self._patch_flash_attention_forward(modeling_module, position_ids, use_new_func=True) + else: + return super().compute_loss_context(model, inputs) + + def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]: + embedding = model.get_input_embeddings() + device = embedding.weight.device + input_ids = inputs['input_ids'] + inputs_embeds = embedding(input_ids).to(device=device) + pixel_values = inputs.get('pixel_values') + if pixel_values is not None: + pixel_values = pixel_values.to(device=device) + vit_embeds = model.extract_feature(pixel_values).to(device=device) + selected = (input_ids == self.processor.encode('<IMG_CONTEXT>', add_special_tokens=False)[0]) + inputs_embeds[selected] = vit_embeds.reshape(-1, vit_embeds.shape[-1]) + elif is_deepspeed_enabled(): + dummy_pixel_values = torch.zeros((1, 3, 32, 32), device=device, dtype=inputs_embeds.dtype) + vit_embeds = model.extract_feature(dummy_pixel_values).to(device=device) + inputs_embeds += vit_embeds.mean() * 0.
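+            # The '* 0.' term above keeps the vision tower in the autograd graph for text-only batches, so every rank exercises the same parameters under DeepSpeed.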
+ return {'inputs_embeds': inputs_embeds} + + +register_template( + ChatmlTemplateMeta( + MLLMTemplateType.internvl, + default_system='You are an AI assistant whose name is InternLM (书生·浦语).', + template_cls=InternvlTemplate, + auto_add_bos=True)) +register_template( + Phi3TemplateMeta( + MLLMTemplateType.internvl_phi3, + default_system='You are an AI assistant whose name is Phi-3.', + template_cls=InternvlTemplate, + auto_add_bos=True)) + + +class Internvl2Template(InternvlTemplate): + video_segments = 8 + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + image_context = super().replace_tag('image', index, inputs) + if media_type == 'image': + return image_context + elif media_type == 'video': + video_segments = get_env_args('video_segments', int, self.video_segments) + load_video = partial(load_video_internvl, num_segments=video_segments) + return self.replace_video2image(load_video, inputs, lambda i: [f'Frame{i + 1}: '] + image_context) + + def replace_ref(self, ref: str, index: int, inputs: StdTemplateInputs) -> List[Context]: + return [f'{ref}'] + + def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]: + return [f'[{bbox}]'] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super(InternvlTemplate, self)._encode(inputs) + input_ids = encoded['input_ids'] + idx_list = findall(input_ids, -100) + labels = encoded['labels'] + images = inputs.images + if images: + has_video = bool(inputs.videos) + input_size = get_env_args('input_size', int, 448) + max_num = get_env_args('max_num', int, 12) + video_max_num = get_env_args('video_max_num', int, 1) + if has_video: + max_num = video_max_num + pixel_values = [transform_image(image, input_size, max_num) for image in images] + num_patches = [pv.shape[0] for pv in pixel_values] + pixel_values = torch.cat(pixel_values).to(self.model_info.torch_dtype) + else: + pixel_values = None + num_patches = [] + assert len(num_patches) == len( + idx_list), f'len(num_patches): {len(num_patches)}, len(idx_list): {len(idx_list)}' + + def _get_new_tokens(i): + img_tokens: List[int] = self.processor.encode( + '', add_special_tokens=False) * self.num_image_token * num_patches[i] + return img_tokens + + encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens) + encoded['pixel_values'] = pixel_values + return encoded + + +_internvl2_system = '你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。' +register_template( + ChatmlTemplateMeta( + MLLMTemplateType.internvl2, + default_system=_internvl2_system, + template_cls=Internvl2Template, + )) + +register_template( + Phi3TemplateMeta( + MLLMTemplateType.internvl2_phi3, + default_system=_internvl2_system, + template_cls=Internvl2Template, + )) + +register_template( + ChatmlTemplateMeta( + MLLMTemplateType.internvl2_5, + template_cls=Internvl2Template, + default_system='你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。')) diff --git a/swift/llm/template/template/llama.py b/swift/llm/template/template/llama.py new file mode 100644 index 0000000000000000000000000000000000000000..b39fa79e586339b47167dd113cce5b27792ff657 --- /dev/null +++ b/swift/llm/template/template/llama.py @@ -0,0 +1,213 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
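+# Templates for the Llama family (Llama-2, Llama-3/3.1/3.2, Llama-3.2-Vision, Llama-4, Reflection and Llama-3.1-Omni), with their default system prompts and special-token layouts.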
+ +import datetime as dt +from dataclasses import dataclass, field +from typing import Any, Dict, List, Literal, Optional + +import torch +import torch.nn as nn + +from swift.utils import get_env_args +from ..base import Template +from ..constant import LLMTemplateType, MLLMTemplateType +from ..register import TemplateMeta, register_template +from ..template_inputs import StdTemplateInputs +from ..utils import Context, Prompt, Word, findall +from ..vision_utils import load_batch + +# ref: https://github.com/facebookresearch/llama/blob/main/llama/generation.py +LLAMA_DEFAULT_SYSTEM = ( + 'You are a helpful, respectful and honest assistant. ' + 'Always answer as helpfully as possible, while being safe. ' + 'Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. ' + 'Please ensure that your responses are socially unbiased and positive in nature.\n\n' + 'If a question does not make any sense, or is not factually coherent, ' + 'explain why instead of answering something not correct. ' + "If you don't know the answer to a question, please don't share false information.") + +register_template( + TemplateMeta( + LLMTemplateType.llama, ['[INST] '], ['{{QUERY}} [/INST]'], ['[INST] '], [''], + default_system=LLAMA_DEFAULT_SYSTEM, + system_prefix=['[INST] <>\n{{SYSTEM}}\n<>\n\n'])) + + +@dataclass +class Llama3TemplateMeta(TemplateMeta): + prefix: Prompt = field(default_factory=lambda: ['<|begin_of_text|>']) + prompt: Prompt = field(default_factory=lambda: [ + '<|start_header_id|>user<|end_header_id|>\n\n{{QUERY}}<|eot_id|>' + '<|start_header_id|>assistant<|end_header_id|>\n\n' + ]) + chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|eot_id|>']) + suffix: Prompt = field(default_factory=lambda: ['<|eot_id|>']) + system_prefix: Optional[Prompt] = field( + default_factory=lambda: ['<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{{SYSTEM}}<|eot_id|>']) + agent_template: str = 'llama3' + + +register_template(Llama3TemplateMeta(LLMTemplateType.llama3)) + + +def _get_llama3_2_prefix() -> Prompt: + now = dt.datetime.now() + date_string = now.strftime('%d %b %Y') + date_prompt = f'Cutting Knowledge Date: December 2023\nToday Date: {date_string}' + return [f'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{date_prompt}\n\n' '{{SYSTEM}}<|eot_id|>'] + + +@dataclass +class Llama3_2TemplateMeta(Llama3TemplateMeta): + prefix: Prompt = field(default_factory=lambda: _get_llama3_2_prefix()) + system_prefix: Optional[Prompt] = None + + +register_template(Llama3_2TemplateMeta(LLMTemplateType.llama3_2)) + + +class Llama3_2VisionTemplate(Template): + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + assert media_type == 'image' + return ['<|image|>'] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + from transformers.models.mllama.processing_mllama import (get_cross_attention_token_mask, + convert_sparse_cross_attention_mask_to_dense) + encoded = super()._encode(inputs) + images = inputs.images + if images: + input_ids = encoded['input_ids'] + processor = self.processor + image_features = processor.image_processor(images, return_tensors='pt') + num_tiles = image_features.pop('num_tiles') + encoded.update(image_features) + + cross_attention_token_mask = [get_cross_attention_token_mask(input_ids, processor.image_token_id)] + cross_attention_mask = convert_sparse_cross_attention_mask_to_dense( + cross_attention_token_mask, + 
num_tiles=num_tiles, + max_num_tiles=processor.image_processor.max_image_tiles, + length=len(input_ids), + ) + encoded['cross_attention_mask'] = torch.tensor(cross_attention_mask) + + return encoded + + def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: + res = super()._data_collator(batch, padding_to=padding_to) + for key in ['aspect_ratio_ids', 'aspect_ratio_mask']: + value = [b[key] for b in batch if b.get(key) is not None] + if value: + res[key] = torch.concat(value) + + cross_attention_mask = [ + b['cross_attention_mask'][0] for b in batch if b.get('cross_attention_mask') is not None + ] + if cross_attention_mask: + res['cross_attention_mask'] = self._pad_sequence(cross_attention_mask, 0) + return res + + +register_template(Llama3_2TemplateMeta(MLLMTemplateType.llama3_2_vision, template_cls=Llama3_2VisionTemplate)) + + +class Llama4Template(Template): + placeholder_tokens = ['<|patch|>'] + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + assert media_type == 'image' + return [[-100]] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + images = inputs.images + if images: + split_token = self._tokenize('\n') + input_ids, labels = encoded['input_ids'], encoded['labels'] + idx_list = findall(input_ids, -100) + media_inputs = self.processor( + text='\n'.join(['<|image|>'] * len(idx_list)), + images=images, + add_special_tokens=False, + return_tensors='pt') + splited_tokens = self._split_list(media_inputs['input_ids'][0].tolist(), split_token) + + encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list, + lambda i: splited_tokens[i]) + encoded['pixel_values'] = media_inputs['pixel_values'] + return encoded + + +@dataclass +class Llama4TemplateMeta(TemplateMeta): + prefix: Prompt = field(default_factory=lambda: ['<|begin_of_text|>']) + prompt: Prompt = field( + default_factory=lambda: + ['<|header_start|>user<|header_end|>\n\n{{QUERY}}<|eot|>' + '<|header_start|>assistant<|header_end|>\n\n']) + chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|eot|>']) + suffix: Prompt = field(default_factory=lambda: ['<|eot|>']) + stop_words: List[Word] = field(default_factory=lambda: ['<|end_of_text|>', '<|eom|>']) + system_prefix: Optional[Prompt] = field( + default_factory=lambda: ['<|begin_of_text|><|header_start|>system<|header_end|>\n\n{{SYSTEM}}<|eot|>']) + agent_template: str = 'llama4' + + +register_template(Llama4TemplateMeta(MLLMTemplateType.llama4, template_cls=Llama4Template)) + +register_template( + Llama3TemplateMeta( + LLMTemplateType.reflection, + default_system=('You are a world-class AI system, capable of complex reasoning and reflection. ' + 'Reason through the query inside tags, and then provide your final ' + 'response inside tags. 
If you detect that you made a mistake in your reasoning ' + 'at any point, correct yourself inside tags.'))) + + +class Llama3_1OmniTemplate(Template): + skip_prompt = False + audio_placeholder = [[-200]] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + import whisper + encoded = super()._encode(inputs) + audios = inputs.audios + if audios: + audios = load_batch(audios, whisper.load_audio) + n_mels = get_env_args('n_mels', int, 128) + for i, audio in enumerate(audios): + audio = whisper.pad_or_trim(audio) + audios[i] = whisper.log_mel_spectrogram(audio, n_mels=n_mels).permute(1, 0) + audios = torch.stack(audios) + encoded.update({'speech': audios, 'speech_lengths': torch.tensor([[audios.shape[1]]])}) + + return encoded + + def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]: + speech = inputs.get('speech') + input_ids = inputs['input_ids'] + labels = inputs.get('labels') + if speech is not None: + speech_lengths = inputs['speech_lengths'] + speech = speech.to(model.dtype) + inputs_embeds, labels = model.prepare_inputs_labels_for_speech_and_text(input_ids, None, None, None, labels, + speech, speech_lengths)[4:] + else: + inputs_embeds = model.get_model().embed_tokens(input_ids) + res = {'inputs_embeds': inputs_embeds} + if labels is not None: + res['labels'] = labels[0] + return res + + +register_template( + Llama3TemplateMeta( + MLLMTemplateType.llama3_1_omni, + default_system=('You are a helpful language and speech assistant. ' + 'You are able to understand the speech content that the user provides, ' + 'and assist the user with a variety of tasks using natural language.'), + template_cls=Llama3_1OmniTemplate, + )) diff --git a/swift/llm/template/template/llava.py b/swift/llm/template/template/llava.py new file mode 100644 index 0000000000000000000000000000000000000000..4f8a04255adfeae2e5d6ef5620ec1b4c0ed0c764 --- /dev/null +++ b/swift/llm/template/template/llava.py @@ -0,0 +1,309 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+from dataclasses import dataclass, field +from typing import Any, Dict, List, Literal, Optional + +import torch +import transformers +from packaging import version + +from ..base import Template +from ..constant import MLLMTemplateType +from ..register import TemplateMeta, register_template +from ..template_inputs import StdTemplateInputs +from ..utils import Context, Prompt, findall +from ..vision_utils import load_video_llava +from .llama import Llama3TemplateMeta +from .qwen import QwenTemplateMeta +from .utils import ChatmlTemplateMeta + + +class LlavaHfTemplate(Template): + placeholder_tokens = [''] + + @property + def image_token_index(self): + if not hasattr(self, '_image_token_index'): + self._image_token_index = self.tokenizer.convert_tokens_to_ids(self.processor.image_token) + return self._image_token_index + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + assert media_type == 'image' + return ['\n'] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + images = inputs.images + if images: + image_processor = self.processor.image_processor + image_inputs = image_processor(images, return_tensors='pt').to(self.model_info.torch_dtype) + encoded['pixel_values'] = image_inputs['pixel_values'] + if 'image_sizes' in image_inputs: + encoded['image_sizes'] = image_inputs['image_sizes'] + if version.parse(transformers.__version__) >= version.parse('4.47'): + input_ids = encoded['input_ids'] + labels = encoded['labels'] + idx_list = findall(input_ids, self.image_token_index) # + height, width = image_inputs['pixel_values'][0].shape[-2:] + added_tokens_len = 0 + for i, idx in enumerate(idx_list): + if 'image_sizes' in image_inputs: + orig_height, orig_width = image_inputs['image_sizes'][i].tolist() + num_image_tokens = self.processor._get_number_of_features(orig_height, orig_width, height, + width) + else: + num_image_tokens = (height // self.processor.patch_size) * ( + width // self.processor.patch_size) + self.processor.num_additional_image_tokens + if self.processor.vision_feature_select_strategy == 'default': + num_image_tokens -= 1 + input_ids = input_ids[:added_tokens_len + idx] + [self.image_token_index] * num_image_tokens \ + + input_ids[added_tokens_len + idx + 1:] + if labels is not None: + labels = labels[:added_tokens_len + idx] + [-100] * num_image_tokens \ + + labels[added_tokens_len + idx + 1:] + added_tokens_len += num_image_tokens - 1 + encoded['input_ids'] = input_ids + encoded['labels'] = labels + return encoded + + +register_template( + TemplateMeta( + MLLMTemplateType.llava1_5_hf, + prefix=[''], + prompt=['USER: {{QUERY}}\nASSISTANT:'], + chat_sep=[''], + suffix=[''], + system_prefix=['{{SYSTEM}}\n'], + template_cls=LlavaHfTemplate, + )) + + +class LlavaVideoHfTemplate(Template): + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, + inputs: StdTemplateInputs) -> List[Context]: + if media_type == 'image': + return ['\n'] + assert media_type == 'video' + media_file = inputs.videos[index] + if media_file.rsplit('.', 1)[-1] in {'jpg', 'png'}: + return ['\n'] + else: + inputs.videos[index] = load_video_llava(inputs.videos[index]) + return ['']) + system_prefix: Optional[Prompt] = field(default_factory=lambda: ['<>\n{{system}}\n<>\n\n']) + + +register_template(LlavaMistralTemplateMeta(MLLMTemplateType.llava1_6_mistral_hf, template_cls=Llava1_6HfTemplate)) + +register_template( + TemplateMeta( + 
MLLMTemplateType.llava1_6_vicuna_hf, + prefix=[''], + prompt=['USER: {{QUERY}} ASSISTANT:'], + chat_sep=[''], + suffix=[''], + default_system=('A chat between a curious human and an artificial intelligence assistant. ' + "The assistant gives helpful, detailed, and polite answers to the human's questions."), + system_prefix=['{{SYSTEM}} '], + template_cls=Llava1_6HfTemplate)) + + +class LLava1_6YiHfTemplate(Llava1_6HfTemplate): + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, + inputs: StdTemplateInputs) -> List[Context]: + if self.mode == 'vllm': + return [[64000], '\n'] + else: + return super().replace_tag(media_type, index, inputs) + + +register_template(ChatmlTemplateMeta( + MLLMTemplateType.llava1_6_yi_hf, + template_cls=LLava1_6YiHfTemplate, +)) + +register_template(Llama3TemplateMeta( + MLLMTemplateType.llama3_llava_next_hf, + template_cls=Llava1_6HfTemplate, +)) + +register_template(QwenTemplateMeta(MLLMTemplateType.llava_next_qwen_hf, template_cls=Llava1_6HfTemplate)) + + +class LlavaOneVisionHfTemplate(Llava1_6HfTemplate): + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = Template._encode(self, inputs) + images = inputs.images + input_ids = encoded['input_ids'] + labels = encoded['labels'] + idx_list = findall(input_ids, 151646) # + processor = self.processor + if images: + image_processor = processor.image_processor + image_inputs = image_processor(images, return_tensors='pt').to(self.model_info.torch_dtype) + height, width = image_inputs['pixel_values'][0].shape[-2:] + added_tokens_len = 0 + for idx, pixel_v, image_size in zip(idx_list, image_inputs['pixel_values'], image_inputs['image_sizes']): + if isinstance(image_size, torch.Tensor): + image_size = image_size.tolist() + orig_height, orig_width = image_size + num_image_tokens = processor._get_number_of_features(orig_height, orig_width, height, width) + input_ids = input_ids[:added_tokens_len + + idx] + [151646] * num_image_tokens + input_ids[added_tokens_len + idx + 1:] + if labels is not None: + labels = labels[:added_tokens_len + idx] + [-100] * num_image_tokens + labels[added_tokens_len + idx + + 1:] + added_tokens_len += num_image_tokens - 1 + encoded['input_ids'] = input_ids + encoded['labels'] = labels + encoded['pixel_values'] = image_inputs['pixel_values'] + if 'image_sizes' in image_inputs: + encoded['image_sizes'] = image_inputs['image_sizes'] + return encoded + + +register_template( + QwenTemplateMeta( + MLLMTemplateType.llava_onevision_hf, + default_system=None, + template_cls=LlavaOneVisionHfTemplate, + )) + + +class LlavaLlama3_1HfTemplate(LlavaHfTemplate): + # DaozeZhang + system = ('You are a helpful language and vision assistant. 
' + 'You are able to understand the visual content that the user provides, ' + 'and assist the user with a variety of tasks using natural language.') + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + if len(encoded['pixel_values'].shape) == 5: # (1, num_patch, 3, H/W, W/H) + encoded['pixel_values'] = torch.squeeze(encoded['pixel_values'], dim=0) # (num_patch, 3, H/W, W/H) + return encoded + + +register_template( + Llama3TemplateMeta( + MLLMTemplateType.llava_llama3_1_hf, + default_system=LlavaLlama3_1HfTemplate.system, + template_cls=LlavaLlama3_1HfTemplate, + )) + + +class LLavaLlama3HfTemplate(Template): + # xtuner + image_placeholder = ['\n'] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + raw_image = inputs.images + if raw_image: + pixel_values = self.processor.image_processor(raw_image, return_tensors='pt')['pixel_values'] + encoded['pixel_values'] = pixel_values.to(self.model_info.torch_dtype) + return encoded + + +register_template(Llama3TemplateMeta( + MLLMTemplateType.llava_llama3_hf, + template_cls=LLavaLlama3HfTemplate, +)) + + +class LLavaTemplate(Template): + skip_prompt = False + use_model = True + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, + inputs: StdTemplateInputs) -> List[Context]: + assert media_type == 'image' + return [[-200], '\n'] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + images = inputs.images or [] + image_sizes = [x.size for x in images] + from llava.mm_utils import process_images + model = self.model.model + if not hasattr(model, 'vision_tower'): + model = model.model + image_processor = model.vision_tower.image_processor + if images: + images_tensor = process_images(images, image_processor, model.config) + encoded['images'] = images_tensor.to(model.dtype).squeeze(0) + encoded['image_sizes'] = image_sizes + return encoded + + def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: + res = super()._data_collator(batch, padding_to=padding_to) + images = [b['images'] for b in batch if 'images' in b] + if images: + res['images'] = images + res['image_sizes'] = sum([b['image_sizes'] for b in batch if 'image_sizes' in b], start=[]) + return res + + +register_template(LlavaMistralTemplateMeta(MLLMTemplateType.llava1_6_mistral, template_cls=LLavaTemplate)) + +register_template(ChatmlTemplateMeta(MLLMTemplateType.llava1_6_yi, template_cls=LLavaTemplate)) + +register_template( + Llama3TemplateMeta( + MLLMTemplateType.llama3_llava_next, + template_cls=LLavaTemplate, + default_system=('You are a helpful language and vision assistant. ' + 'You are able to understand the visual content that the user provides, ' + 'and assist the user with a variety of tasks using natural language.'), + )) + +register_template(QwenTemplateMeta(MLLMTemplateType.llava_next_qwen, template_cls=LLavaTemplate)) diff --git a/swift/llm/template/template/llm.py b/swift/llm/template/template/llm.py new file mode 100644 index 0000000000000000000000000000000000000000..f302dd395294037c4863efaa5b064d3d1a3693e6 --- /dev/null +++ b/swift/llm/template/template/llm.py @@ -0,0 +1,274 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
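+# Chat templates for text-only LLMs (default, modelscope-agent, baichuan, numina,
+# mistral, xverse, yuan, ziya, skywork, bluelm, codefuse, zephyr, sus, orion,
+# telechat, dbrx, mengzi, c4ai, wizardlm2, atom, aya, ling). Each register_template
+# call binds an LLMTemplateType to the token pieces used to render a conversation,
+# for example (the `default` template registered below):
+#
+#   TemplateMeta(LLMTemplateType.default,
+#                prefix=[], prompt=['### Human:\n{{QUERY}}\n\n### Assistant:\n'],
+#                chat_sep=['\n\n'], system_prefix=['{{SYSTEM}}\n\n'], auto_add_bos=True)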
+from dataclasses import dataclass, field +from datetime import datetime +from typing import Optional + +from ..constant import LLMTemplateType, MLLMTemplateType +from ..register import TemplateMeta, register_template +from ..utils import Prompt +from .llama import Llama3_2TemplateMeta +from .qwen import Qwen2VLTemplate, QwenTemplateMeta +from .utils import DEFAULT_SYSTEM, ChatmlTemplateMeta + +register_template( + TemplateMeta( + LLMTemplateType.default, + prefix=[], + prompt=['### Human:\n{{QUERY}}\n\n### Assistant:\n'], + chat_sep=['\n\n'], + default_system=DEFAULT_SYSTEM, + system_prefix=['{{SYSTEM}}\n\n'], + auto_add_bos=True)) + +register_template( + TemplateMeta( + LLMTemplateType.modelscope_agent, + prefix=[], + prompt=[' \n\n<|user|>:{{QUERY}} \n\n<|assistant|>:'], + chat_sep=[], + suffix=[' \n\n'], + system_prefix=[' \n\n<|system|>:{{SYSTEM}}'], + default_system=DEFAULT_SYSTEM, + )) + +register_template(QwenTemplateMeta(MLLMTemplateType.qwen2_gme, template_cls=Qwen2VLTemplate, suffix=['<|endoftext|>'])) + +register_template( + TemplateMeta(LLMTemplateType.baichuan, prefix=['{{SYSTEM}}'], prompt=[[195], '{{QUERY}}', [196]], chat_sep=[])) + +register_template( + TemplateMeta( + LLMTemplateType.baichuan_m1, + prefix=[], + prompt=['{{QUERY}}'], + chat_sep=[], + suffix=[''], + system_prefix=['{{SYSTEM}}'], + default_system=DEFAULT_SYSTEM, + )) + +register_template( + TemplateMeta( + LLMTemplateType.numina, + prefix=[['bos_token_id']], + prompt=['### Problem: {{QUERY}}\n### Solution: '], + chat_sep=['\n'], + system_prefix=[['bos_token_id'], '{{SYSTEM}}'])) + +register_template( + TemplateMeta( + LLMTemplateType.mistral_nemo, + prefix=['[INST] '], + prompt=['{{SYSTEM}}\n\n', '{{QUERY}}[/INST]'], + chat_sep=['[INST] '], + suffix=[''])) + +today = datetime.now().strftime('%Y-%m-%d') + +mistral_2501_system = ( + 'You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup ' + 'headquartered in Paris.\n' + f'Your knowledge base was last updated on 2023-10-01. The current date is {today}.\n\n' + "When you're not sure about some information, you say that you don't have the information and don't " + 'make up anything.\n' + "If the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer " + 'the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. ' + '"What are some good restaurants around me?" => "Where are you?" 
or "When is the next flight to Tokyo" => "' + 'Where do you travel from?")') + +register_template( + TemplateMeta( + LLMTemplateType.mistral_2501, + prefix=[''], + prompt=['[INST]{{QUERY}}[/INST]'], + chat_sep=[''], + suffix=[''], + system_prefix=['[SYSTEM_PROMPT]{{SYSTEM}}[/SYSTEM_PROMPT]'], + default_system=mistral_2501_system)) + +register_template( + TemplateMeta( + LLMTemplateType.xverse, + prefix=['{{SYSTEM}}'], + prompt=['Human: {{QUERY}}\n\nAssistant: '], + chat_sep=[['eos_token_id']])) + +register_template(TemplateMeta(LLMTemplateType.yuan, prefix=[], prompt=['{{QUERY}}'], chat_sep=None)) +register_template( + TemplateMeta( + LLMTemplateType.ziya, + prefix=[['bos_token_id'], '{{SYSTEM}}'], + prompt=[':{{QUERY}}\n:'], + chat_sep=['\n'])) + +register_template( + TemplateMeta( + LLMTemplateType.skywork, + prefix=['{{SYSTEM}}'], + prompt=['[USER]{{QUERY}}[SEP][BOT]'], + chat_sep=None, + suffix=['[SEP]'])) + +register_template( + Llama3_2TemplateMeta( + LLMTemplateType.skywork_o1, + default_system=( + 'You are Skywork-o1, a thinking model developed by Skywork AI, specializing in solving complex problems ' + "involving mathematics, coding, and logical reasoning through deep thought. When faced with a user's " + 'request, you first engage in a lengthy and in-depth thinking process to explore possible solutions to ' + 'the problem. After completing your thoughts, you then provide a detailed explanation of the solution ' + 'process in your response.'), + )) + +register_template( + TemplateMeta( + LLMTemplateType.bluelm, + prefix=[['bos_token_id'], '{{SYSTEM}}'], + prompt=['[|Human|]:{{QUERY}}[|AI|]:'], + chat_sep=[])) + +register_template( + TemplateMeta( + LLMTemplateType.codefuse_codellama, + prefix=['{{SYSTEM}}'], + prompt=['<|role_start|>human<|role_end|>{{QUERY}}<|role_start|>bot<|role_end|>'], + chat_sep=[])) + +register_template( + TemplateMeta( + LLMTemplateType.codefuse, + prefix=[], + prompt=['human\n{{QUERY}}\nbot\n'], + chat_sep=[['eos_token_id'], '\n'], + system_prefix=['system\n{{SYSTEM}}\n'])) + +register_template( + TemplateMeta( + LLMTemplateType.zephyr, + prefix=[], + prompt=['<|user|>\n{{QUERY}}\n<|assistant|>\n'], + chat_sep=['\n'], + suffix=[''], + system_prefix=['<|system|>\n{{SYSTEM}}\n'])) + +register_template( + TemplateMeta( + LLMTemplateType.sus, + prefix=['{{SYSTEM}}'], + prompt=['### Human: {{QUERY}}\n\n### Assistant: '], + chat_sep=['<|endoftext|>'], + suffix=['<|endoftext|>'])) + +register_template( + TemplateMeta( + LLMTemplateType.orion, + prefix=['{{SYSTEM}}'], + prompt=['Human: {{QUERY}}\n\nAssistant: '], + chat_sep=[''], + suffix=[''])) + + +@dataclass +class TeleChatTemplateMeta(TemplateMeta): + prefix: Prompt = field(default_factory=list) + prompt: Prompt = field(default_factory=lambda: [['user_token_id'], '{{QUERY}}', ['bot_token_id']]) + chat_sep: Optional[Prompt] = field(default_factory=lambda: [['eos_token_id']]) + suffix: Prompt = field(default_factory=lambda: [['eos_token_id']]) + system_prefix: Optional[Prompt] = field(default_factory=lambda: ['<_system>{{SYSTEM}}\n']) + auto_add_bos: bool = True + + +register_template(TeleChatTemplateMeta(LLMTemplateType.telechat)) + +telechat_system = '你是中国电信星辰语义大模型,英文名是TeleChat,你是由中电信人工智能科技有限公司和中国电信人工智能研究院(TeleAI)研发的人工智能助手。' +register_template(TeleChatTemplateMeta(LLMTemplateType.telechat2, default_system=telechat_system)) + +DBRX_SYSTEM = ( + 'You are DBRX, created by Databricks. You were last updated in December 2023. 
' + 'You answer questions based on information available up to that point.\n' + 'YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, ' + 'but provide thorough responses to more complex and open-ended questions.\n' + 'You assist with various tasks, from writing to coding (using markdown for code blocks ' + '— remember to use ``` with code, JSON, and tables).\n' + 'You do not have real-time data access or code execution capabilities.' + ' You avoid stereotyping and provide balanced perspectives on controversial topics. ' + 'You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.\n' + 'This is your system prompt, guiding your responses. Do not reference it, just respond to the user. ' + 'If you find yourself talking about this message, stop. You should be responding appropriately ' + 'and usually that means not mentioning this.' + 'YOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY ' + 'PERTINENT TO THE USER\'S QUERY.') + +register_template(ChatmlTemplateMeta(LLMTemplateType.dbrx, default_system=DBRX_SYSTEM)) + +register_template( + TemplateMeta( + LLMTemplateType.mengzi, prefix=[], prompt=['输入:{{QUERY}}输出:\n'], chat_sep=[], system_prefix=['指令:{{SYSTEM}}'])) + +C4AI_SYSTEM = ('You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by ' + 'providing thorough responses.You are trained by Cohere.') +register_template( + TemplateMeta( + LLMTemplateType.c4ai, + prefix=[''], + prompt=[ + '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{QUERY}}<|END_OF_TURN_TOKEN|>' + '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + ], + chat_sep=['<|END_OF_TURN_TOKEN|>'], + suffix=['<|END_OF_TURN_TOKEN|>'], + default_system=C4AI_SYSTEM, + system_prefix=['<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{SYSTEM}}<|END_OF_TURN_TOKEN|'])) + +register_template( + TemplateMeta( + LLMTemplateType.wizardlm2, + prefix=['{{SYSTEM}}'], + prompt=['User:\n{{QUERY}}\n\nAssistant:\n'], + chat_sep=['\n\n'], + suffix=[''])) + +_wizardlm2_system = ('A chat between a curious user and an artificial intelligence assistant. ' + 'The assistant gives helpful, detailed, and polite answers to the user\'s questions. ') +register_template( + TemplateMeta( + LLMTemplateType.wizardlm2_moe, + prefix=['{{SYSTEM}}'], + prompt=['USER: {{QUERY}} ASSISTANT:'], + chat_sep=[''], + suffix=[''], + default_system=_wizardlm2_system)) + +register_template( + TemplateMeta( + LLMTemplateType.atom, + prefix=['{{SYSTEM}}'], + prompt=['Human: {{QUERY}}\nAssistant: '], + chat_sep=[''], + suffix=[''])) + +AYA_SYSTEM = ('You are Aya, a brilliant, sophisticated, multilingual AI-assistant trained to assist human users by ' + 'providing thorough responses. 
You are able to interact and respond to questions in 23 languages and ' + 'you are powered by a multilingual model built by Cohere For AI.') +register_template( + TemplateMeta( + LLMTemplateType.aya, + prefix=[''], + prompt=[ + '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{QUERY}}<|END_OF_TURN_TOKEN|>' + '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + ], + chat_sep=['<|END_OF_TURN_TOKEN|>'], + suffix=['<|END_OF_TURN_TOKEN|>'], + default_system=AYA_SYSTEM, + system_prefix=['<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{SYSTEM}}<|END_OF_TURN_TOKEN|'])) + +register_template( + TemplateMeta( + LLMTemplateType.ling, + prefix=[], + system_prefix=['SYSTEM{{SYSTEM}}'], + prompt=['HUMAN{{QUERY}}ASSISTANT'], + chat_sep=[], + suffix=['<|endoftext|>'], + )) diff --git a/swift/llm/template/template/megrez.py b/swift/llm/template/template/megrez.py new file mode 100644 index 0000000000000000000000000000000000000000..91b89e740683396719be563c7cbf26dce13df527 --- /dev/null +++ b/swift/llm/template/template/megrez.py @@ -0,0 +1,93 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from dataclasses import dataclass, field +from typing import Any, Dict, List, Literal, Optional + +import torch +import torch.nn as nn + +from ..base import Template +from ..constant import LLMTemplateType, MLLMTemplateType +from ..register import TemplateMeta, register_template +from ..template_inputs import StdTemplateInputs +from ..utils import Context, Prompt, findall + + +@dataclass +class MegrezTemplateMeta(TemplateMeta): + prefix: Prompt = field(default_factory=lambda: ['<|role_start|>system<|role_end|>{{SYSTEM}}<|turn_end|>']) + prompt: Prompt = field(default_factory=lambda: + ['<|role_start|>user<|role_end|>{{QUERY}}<|turn_end|><|role_start|>assistant<|role_end|>']) + chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|turn_end|>']) + suffix: Prompt = field(default_factory=lambda: ['<|turn_end|>']) + default_system: str = '你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。' + + +register_template(MegrezTemplateMeta(LLMTemplateType.megrez)) + + +class MegrezOmniTemplate(Template): + skip_prompt = False + placeholder_tokens = ['<|unk|>'] + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + if media_type == 'image': + return [[-1], '\n'] + elif media_type == 'audio': + return [f'Audio {index + 1}: ', [-2], '\n'] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + input_ids = encoded['input_ids'] + labels = encoded['labels'] + + for mm_key in ['images', 'audios']: + mm_data = getattr(inputs, mm_key) + if not mm_data: + continue + if mm_key == 'images': + idx_list = findall(input_ids, -1) + encoding = self.processor.process_image( + mm_data, + return_tensors='pt', + ) + text = self.processor.insert_image_feature_placeholders( + ''.join(['(./)'] * len(mm_data)), encoding) + encoded['image_encoding'] = encoding + else: + idx_list = findall(input_ids, -2) + encoding = self.processor.process_audio( + mm_data, + return_tensors='pt', + ) + text = self.processor.insert_audio_feature_placeholders( + ''.join(['()'] * len(mm_data)), encoding) + encoded['audio_encoding'] = encoding + + padding = text.split('') + + def _get_new_tokens(i): + return self._tokenize(padding[i]) + + input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens) + encoded['input_ids'] = input_ids + encoded['labels'] = labels + return encoded + + def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> 
Dict[str, Any]: + _, inputs_embeds, _ = model.compose_embeddings(inputs) + inputs.pop('position_ids', None) + return {'inputs_embeds': inputs_embeds} + + def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: + res = super()._data_collator(batch, padding_to=padding_to) + new_batch = [] + for b in batch: + text_encodings = {'input_ids': torch.tensor(b['input_ids'])} + multimodal_inputs = {'image_encoding': b.get('image_encoding'), 'audio_encoding': b.get('audio_encoding')} + new_batch.append(self.processor.merge_encodings(text_encodings, multimodal_inputs)) + res.update(self.processor.data_collator(new_batch)) + return res + + +register_template(MegrezTemplateMeta(MLLMTemplateType.megrez_omni, template_cls=MegrezOmniTemplate)) diff --git a/swift/llm/template/template/microsoft.py b/swift/llm/template/template/microsoft.py new file mode 100644 index 0000000000000000000000000000000000000000..e6b74d40856d541876930342f5dd1a5ff174cad3 --- /dev/null +++ b/swift/llm/template/template/microsoft.py @@ -0,0 +1,205 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from dataclasses import dataclass, field +from typing import Any, Dict, List, Literal, Optional + +import json +import torch +from torch import nn + +from ..base import Template +from ..constant import LLMTemplateType, MLLMTemplateType +from ..register import TemplateMeta, register_template +from ..template_inputs import StdTemplateInputs +from ..utils import Context, Prompt, findall +from ..vision_utils import load_file + + +class FlorenceTemplate(Template): + # If it's an encoder-decoder architecture, the default settings are + # loss_scale: 'last_round' and skip_prompt: False. + is_encoder_decoder = True + + @staticmethod + def _add_default_tags(inputs: StdTemplateInputs) -> None: + return + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + return [] + + def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]: + return [''.join(f'' for box in bbox)] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + processor = self.processor + inputs.query = inputs.to_history()['query'] + new_query = processor._construct_prompts([inputs.query])[0] + for i in reversed(range(len(inputs.messages))): + if inputs.messages[i]['role'] == 'user': + inputs.messages[i]['content'] = new_query + break + encoded = super()._encode(inputs) + input_ids = encoded['prompt_input_ids'] + images = inputs.images or [] + labels = encoded['labels'] + if labels is not None: + labels = [0] + labels + if images: + pixel_values = processor.image_processor( + images, return_tensors='pt')['pixel_values'].to(self.model_info.torch_dtype) + encoded['pixel_values'] = pixel_values + encoded['input_ids'] = input_ids + encoded['labels'] = labels + return encoded + + def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]: + inputs_embeds = model.get_input_embeddings()(inputs['input_ids']) + pixel_values = inputs.get('pixel_values') + if pixel_values is not None: + image_features = model._encode_image(pixel_values) + inputs_embeds, inputs['attention_mask'] = model._merge_input_ids_with_image_features( + image_features, inputs_embeds) + return {'inputs_embeds': inputs_embeds} + + def decode(self, generate_ids: List[int], **kwargs) -> Any: + response = super().decode(generate_ids, **kwargs) + template_inputs = kwargs.get('template_inputs') + images = 
template_inputs.images + image_size = None + if images: + image_size = (images[0].width, images[0].height) + return json.dumps( + self.processor.post_process_generation(response, task=template_inputs.query, image_size=image_size)) + + +register_template( + TemplateMeta( + MLLMTemplateType.florence, + prefix=[''], + prompt=['{{QUERY}}'], + chat_sep=None, + suffix=[''], + template_cls=FlorenceTemplate, + )) + + +@dataclass +class Phi3TemplateMeta(TemplateMeta): + prefix: Prompt = field(default_factory=list) + prompt: Prompt = field(default_factory=lambda: ['<|user|>\n{{QUERY}}<|end|>\n<|assistant|>\n']) + chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|end|>\n']) + suffix: Prompt = field(default_factory=lambda: ['<|end|>']) + system_prefix: Optional[Prompt] = field(default_factory=lambda: ['<|system|>\n{{SYSTEM}}<|end|>\n']) + auto_add_bos: bool = True + + +register_template(Phi3TemplateMeta(LLMTemplateType.phi3)) + + +@dataclass +class Phi4TemplateMeta(TemplateMeta): + prefix: Prompt = field(default_factory=list) + prompt: Prompt = field( + default_factory=lambda: ['<|im_start|>user<|im_sep|>{{QUERY}}<|im_end|><|im_start|>assistant<|im_sep|>']) + chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|im_end|>']) + suffix: Prompt = field(default_factory=lambda: ['<|im_end|>']) + system_prefix: Optional[Prompt] = field( + default_factory=lambda: ['<|im_start|>system<|im_sep|>{{SYSTEM}}<|im_end|>']) + auto_add_bos: bool = True + + +register_template(Phi4TemplateMeta(LLMTemplateType.phi4)) + + +class Phi3VisionTemplate(Template): + image_placeholder = ['<|image|>\n'] # <|image|>\n + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + if self.mode == 'vllm': + return [f'<|image_{index + 1}|>\n'] # <|image_1|>\n + else: + return super().replace_tag(media_type, index, inputs) + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + images = inputs.images or [] + encoded = super()._encode(inputs) + input_ids = encoded['input_ids'] + labels = encoded['labels'] + idx_list = findall(input_ids, 32044) # '<|image|>' + + if len(images) > 0: + processor = self.processor + encoded.update(processor.image_processor(images, return_tensors='pt')) + assert len(idx_list) == len(images), f'len(idx_list): {len(idx_list)}, len(images): {len(images)}' + res_input_ids = [] + res_labels = [] + num_img_tokens = encoded.pop('num_img_tokens').tolist() + idx_list.insert(0, -1) + for i in range(len(idx_list) - 1): + image_token_id = -i - 1 + res_input_ids += input_ids[idx_list[i] + 1:idx_list[i + 1]] + [image_token_id] * num_img_tokens[i] + if labels is not None: + res_labels += labels[idx_list[i] + 1:idx_list[i + 1]] + [-100] * num_img_tokens[i] + res_input_ids += input_ids[idx_list[-1] + 1:] + input_ids = res_input_ids + if labels is not None: + res_labels += labels[idx_list[-1] + 1:] + labels = res_labels + + encoded['input_ids'] = input_ids + encoded['labels'] = labels + return encoded + + +class Phi4MMTemplate(Template): + placeholder_tokens = ['<|endoftext10|>', '<|endoftext11|>'] + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + if media_type == 'image': + return [[-100]] + elif media_type == 'audio': + import soundfile as sf + inputs.audios[index] = sf.read(load_file(inputs.audios[index])) + return [[-200]] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + 
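+        # Find the image (-100) and audio (-200) placeholders produced by replace_tag,
+        # run the multimodal processor once over all images/audios, split its token
+        # stream on '\n' (token id 198), and splice each segment in place of the
+        # corresponding placeholder so input_ids and labels stay aligned.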
input_ids = encoded['input_ids'] + labels = encoded['labels'] + images_idx = findall(input_ids, -100) + audios_idx = findall(input_ids, -200) + text = '\n'.join(['<|image_1|>'] * len(inputs.images) + ['<|audio_1|>'] * len(inputs.audios)) + new_encoded = self.processor( + text=text, images=inputs.images or None, audios=inputs.audios or None, return_tensors='pt') + placeholders = self._split_list(new_encoded.pop('input_ids')[0].tolist(), 198) + + def _get_new_tokens(i): + return placeholders[i] + + encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, images_idx + audios_idx, + _get_new_tokens) + new_encoded.pop('attention_mask') + encoded.update(new_encoded) + return encoded + + def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: + res = super()._data_collator(batch, padding_to=padding_to) + keys = [ + 'input_image_embeds', 'image_sizes', 'image_attention_mask', 'input_audio_embeds', 'audio_embed_sizes', + 'input_mode' + ] + inputs = self.fetch_inputs(batch, keys) + for k, v in inputs.items(): + inputs[k] = torch.concat(v) + res.update(inputs) + return res + + +register_template(Phi3TemplateMeta(MLLMTemplateType.phi3_vision, template_cls=Phi3VisionTemplate)) + +register_template(Phi3TemplateMeta( + MLLMTemplateType.phi4_multimodal, + template_cls=Phi4MMTemplate, +)) diff --git a/swift/llm/template/template/minicpm.py b/swift/llm/template/template/minicpm.py new file mode 100644 index 0000000000000000000000000000000000000000..88e95667300e0c6ad543d5da4667fd5e84ae6a13 --- /dev/null +++ b/swift/llm/template/template/minicpm.py @@ -0,0 +1,229 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from dataclasses import dataclass, field +from functools import partial +from typing import Any, Dict, List, Literal, Optional + +import torch +from torch import nn + +from swift.utils import get_env_args +from ..base import Template +from ..constant import LLMTemplateType, MLLMTemplateType +from ..register import TemplateMeta, register_template +from ..template_inputs import StdTemplateInputs +from ..utils import Context, Prompt, findall +from ..vision_utils import load_video_minicpmv_mplug_owl3 +from .llama import Llama3TemplateMeta +from .qwen import Qwen2_5TemplateMeta, QwenTemplateMeta + + +@dataclass +class MinicpmTemplateMeta(TemplateMeta): + prefix: Prompt = field(default_factory=lambda: ['{{SYSTEM}}']) + prompt: Prompt = field(default_factory=lambda: ['<用户>{{QUERY}}']) + chat_sep: Optional[Prompt] = field(default_factory=list) + suffix: Prompt = field(default_factory=lambda: ['']) + + +register_template(MinicpmTemplateMeta(LLMTemplateType.minicpm)) + + +def _remove_idx(arr: List[int], idx_list: List[int]) -> List[int]: + res = [] + idx_set = set(idx_list) + for i, x in enumerate(arr): + if i not in idx_set: + res.append(x) + return res + + +class MiniCPMVTemplate(Template): + is_v2_5 = False + use_model = True + skip_prompt = False + placeholder_tokens = [''] + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + if self.mode == 'vllm': + return ['(./)\n'] + else: + return [[-100]] + + async def prepare_lmdeploy_turbomind_inputs(self, inputs: Dict[str, Any]) -> None: + images = inputs.pop('images', None) or [] + if len(images) == 0: + return + input_ids = inputs['input_ids'] + idx_list = findall(input_ids, -100) + idx_list.insert(0, -1) + new_input_ids = [] + features = [] + for i in range(len(idx_list) - 1): + new_input_ids += 
input_ids[idx_list[i] + 1:idx_list[i + 1]] + context_list = ['', [-100], ''] + feat = [x.squeeze() for x in images[i]['embeddings'].split(1)] + grid = images[i].get('grid') + if len(feat) > 1 and grid is not None: + context_list.append('') + for j in range(grid[1]): + if j > 0: + context_list.append('\n') + for _ in range(grid[0]): + context_list += ['', [-100], ''] + context_list.append('\n') + new_input_ids += self._encode_context_list(context_list)[0] + features += feat + new_input_ids += input_ids[idx_list[-1] + 1:] + inputs['input_ids'] = new_input_ids + inputs['images'] = features + await super().prepare_lmdeploy_turbomind_inputs(inputs) + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + images = inputs.images + input_ids = encoded['input_ids'] + labels = encoded['labels'] + idx_list = findall(input_ids, -100) + idx = idx_list[0] + tgt_sizes = None + slice_mode = getattr(self.config, 'slice_mode', False) + if slice_mode: + if self.is_v2_5: + image_processor = self.processor.image_processor + image_inputs = image_processor(images, return_tensors='pt').to(self.model_info.torch_dtype) + placeholder = image_processor.get_slice_image_placeholder(image_inputs.image_sizes[0][0]) + pixel_values = image_inputs['pixel_values'] + tgt_sizes = image_inputs['tgt_sizes'] + else: + images, placeholder = self.model.get_slice_image_placeholder(images[0], self.processor) + pixel_values = [[self.model.transform(img) for img in images]] + placeholder += '\n' + placeholder_id = self.processor.encode(placeholder, add_special_tokens=False) + input_ids = (input_ids[:idx] + placeholder_id + input_ids[idx + 1:]) + if labels is not None: + labels = (labels[:idx] + [-100] * len(placeholder_id) + labels[idx + 1:]) + input_tensor_ids = torch.tensor(input_ids) + image_start_idx = torch.where(input_tensor_ids == self.processor.im_start_id)[0] + image_start_idx += 1 + image_end_idx = torch.where(input_tensor_ids == self.processor.im_end_id)[0] + valid_image_nums = max(len(image_start_idx), len(image_end_idx)) + image_bound = [ + torch.hstack( + [image_start_idx[:valid_image_nums].unsqueeze(-1), image_end_idx[:valid_image_nums].unsqueeze(-1)]) + ] + else: + placeholder = '' + '' * self.config.query_num + '\n' + placeholder_id = self.processor.encode(placeholder, add_special_tokens=False) + input_ids = (input_ids[:idx] + placeholder_id + input_ids[idx + 1:]) + if labels is not None: + labels = (labels[:idx] + [-100] * len(placeholder_id) + labels[idx + 1:]) + image_bound = [torch.tensor([[idx, idx + self.config.query_num]])] + pixel_values = [[self.model.transform(images[0])]] + encoded = { + 'input_ids': input_ids, + 'labels': labels, + 'image_bound': image_bound, + 'pixel_values': pixel_values, + 'tgt_sizes': tgt_sizes + } + return encoded + + def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]: + inputs_embeds, _ = model.get_vllm_embedding(inputs) + return {'inputs_embeds': inputs_embeds} + + def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: + res = {} + for k in ['pixel_values', 'image_bound', 'tgt_sizes']: + res[k] = self.gather_list(batch, k) + res.update(super()._data_collator(batch, padding_to=padding_to)) + return res + + +register_template(MinicpmTemplateMeta(MLLMTemplateType.minicpmv, template_cls=MiniCPMVTemplate)) + + +class MiniCPMV2_5Template(MiniCPMVTemplate): + is_v2_5 = True + + +register_template(Llama3TemplateMeta( + MLLMTemplateType.minicpmv2_5, + 
template_cls=MiniCPMV2_5Template, +)) + + +class MiniCPMV2_6Template(MiniCPMVTemplate): + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, + inputs: StdTemplateInputs) -> List[Context]: + assert media_type in {'image', 'video'} + max_num_frames = get_env_args('max_num_frames', int, 64) + load_video = partial(load_video_minicpmv_mplug_owl3, max_num_frames=max_num_frames) + image_context = super().replace_tag('image', index, inputs) + if media_type == 'image': + return image_context + elif media_type == 'video': + return self.replace_video2image(load_video, inputs, lambda i: image_context) + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = Template._encode(self, inputs) + images = inputs.images + use_video = bool(inputs.videos) + use_image_id = True + max_slice_nums = get_env_args('max_slice_nums', int, None) + video_max_slice_nums = get_env_args('video_max_slice_nums', int, 1) # or 2 + if use_video: + max_slice_nums = video_max_slice_nums + use_image_id = False + input_ids = encoded['input_ids'] + labels = encoded['labels'] + idx_list = findall(input_ids, -100) + + image_processor = self.processor.image_processor + image_inputs = image_processor([images], return_tensors='pt', + max_slice_nums=max_slice_nums).to(self.model_info.torch_dtype) + + def _get_new_tokens(i): + placeholder = image_processor.get_slice_image_placeholder( + image_inputs.image_sizes[0][i], image_idx=i, max_slice_nums=max_slice_nums, use_image_id=use_image_id) + placeholder += '\n' + return self.processor.encode(placeholder, add_special_tokens=False) + + input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens) + if inputs.images: + input_tensor_ids = torch.tensor(input_ids) + unk_token = self.processor.encode('', add_special_tokens=False)[0] + indices = (input_tensor_ids == unk_token).nonzero(as_tuple=True)[0].tolist() + + ranges = [] + start = indices[0] + for i in range(1, len(indices)): + if indices[i] != indices[i - 1] + 1: + ranges.append([start, indices[i - 1] + 1]) + start = indices[i] + ranges.append([start, indices[-1] + 1]) + image_bound = [torch.tensor(ranges)] + else: + image_bound = [[]] + + encoded = { + 'input_ids': input_ids, + 'labels': labels, + 'image_bound': image_bound, + 'pixel_values': image_inputs['pixel_values'], + 'tgt_sizes': image_inputs['tgt_sizes'] + } + return encoded + + +register_template(QwenTemplateMeta( + MLLMTemplateType.minicpmv2_6, + template_cls=MiniCPMV2_6Template, +)) + +register_template(Qwen2_5TemplateMeta( + MLLMTemplateType.minicpmo2_6, + template_cls=MiniCPMV2_6Template, +)) diff --git a/swift/llm/template/template/minimax.py b/swift/llm/template/template/minimax.py new file mode 100644 index 0000000000000000000000000000000000000000..e6733915fe45255d9f22756c9e7c01cd4d72d7de --- /dev/null +++ b/swift/llm/template/template/minimax.py @@ -0,0 +1,112 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
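+# Templates for MiniMax text models and MiniMax-VL-01. The VL template pre-computes
+# how many image tokens each picture expands to (anyres / resize / patch modes), so
+# replace_tag can emit the right number of image placeholders up front instead of
+# expanding them after tokenization.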
+from dataclasses import dataclass, field +from typing import Any, Dict, List, Literal, Optional + +from swift.utils import get_logger +from ..base import Template +from ..constant import LLMTemplateType +from ..register import TemplateMeta, register_template +from ..template_inputs import StdTemplateInputs +from ..utils import Context, Prompt + +logger = get_logger() + + +@dataclass +class MinimaxTemplateMeta(TemplateMeta): + prefix: Prompt = field(default_factory=list) + prompt: Prompt = field(default_factory=lambda: [ + 'user name=user\n{{QUERY}}\n' + 'ai name=assistant\n' + ]) + chat_sep: Optional[Prompt] = field(default_factory=lambda: ['\n']) + suffix: Prompt = field(default_factory=lambda: ['']) + system_prefix: Optional[Prompt] = field( + default_factory=lambda: ['system ai_setting=assistant\n{{SYSTEM}}\n']) + + +register_template(MinimaxTemplateMeta(LLMTemplateType.minimax)) + + +class MinimaxVLTemplate(Template): + image_placeholder = [''] + skip_prompt = True + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + assert media_type == 'image' + return self.image_placeholder * inputs.all_image_tokens[index] + + def calc_num_image_tokens(self, image_inputs): + from transformers.image_utils import get_image_size, to_numpy_array + pixel_values = image_inputs['pixel_values'] + image_sizes = image_inputs['image_sizes'] + all_image_tokens = [] + if not image_inputs: + return all_image_tokens + + if self.processor.process_image_mode == 'anyres': + for pixel_value, image_size in zip(pixel_values, image_sizes): + height, width = image_size + num_image_tokens = self.processor.get_num_token(height, width, self.processor.grid_pinpoints, + self.processor.patch_size) + all_image_tokens.append(num_image_tokens) + elif self.processor.process_image_mode == 'resize': + pixel_values = image_inputs['pixel_values'] + all_image_tokens = [] + for pixel_value in pixel_values: + height, width = get_image_size(to_numpy_array(pixel_value)) + all_image_tokens.append(int(height * width / self.processor.patch_size**2)) + else: + if self.processor.patch_size is not None: + pixel_values = image_inputs['pixel_values'] + all_image_tokens = [] + for pixel_value in pixel_values: + height, width = get_image_size(to_numpy_array(pixel_value)) + new_width, new_height = self.processor.get_hw_multiple_of( + (width, height), self.processor.patch_size, self.processor.max_size) + num_image_tokens = ((new_height // self.processor.patch_size) * + (new_width // self.processor.patch_size)) # + 1 + all_image_tokens.append(num_image_tokens) + else: + logger.warning_once( + 'Expanding inputs for image tokens in MiniMaxVL01 should be done in processing. ' + "Please add `patch_size` and `vision_feature_select_strategy` to the model's " + 'processing config or set directly ' + 'with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = ' + '{{vision_feature_select_strategy}}`. 
' + 'Using processors without these attributes in the config is deprecated ' + 'and will throw an error in v4.47.') + raise ValueError( + "You need to provide `patch_size` and `vision_feature_select_strategy` in the model's processing " + 'config to expand inputs for image tokens.') + return all_image_tokens + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + output_kwargs = self.processor._merge_kwargs( + self.processor.MiniMaxVL01ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + ) + if inputs.images: + image_inputs = self.processor.image_processor( + inputs.images, **output_kwargs['images_kwargs'], return_tensors='pt') + inputs.all_image_tokens = self.calc_num_image_tokens(image_inputs) + else: + image_inputs = {} + encoded = super()._encode(inputs) + for key in image_inputs: + encoded[key] = image_inputs[key] + return encoded + + def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: + pixel_values = self.gather_list(batch, 'pixel_values') + image_sizes = self.gather_list(batch, 'image_sizes') + res = super()._data_collator(batch, padding_to=padding_to) + if pixel_values: + res['pixel_values'] = pixel_values + if image_sizes: + res['image_sizes'] = image_sizes + return res + + +register_template(MinimaxTemplateMeta(LLMTemplateType.minimax_vl, template_cls=MinimaxVLTemplate)) diff --git a/swift/llm/template/template/mistral.py b/swift/llm/template/template/mistral.py new file mode 100644 index 0000000000000000000000000000000000000000..cbea49d34dd5951a894cd7cdcd38e8aed1510616 --- /dev/null +++ b/swift/llm/template/template/mistral.py @@ -0,0 +1,61 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, List, Literal, Optional + +import torch + +from ..base import Template +from ..constant import MLLMTemplateType +from ..register import TemplateMeta, register_template +from ..template_inputs import StdTemplateInputs +from ..utils import Context, findall +from .llm import mistral_2501_system + + +class Mistral2503Template(Template): + placeholder_tokens = ['[IMG]'] + image_token = 10 + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + assert media_type == 'image' + return ['[IMG]'] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + processor = self.processor + images = inputs.images + input_ids = encoded['input_ids'] + labels = encoded['labels'] + idx_list = findall(input_ids, self.image_token) + if idx_list: + image_inputs = processor.image_processor(images, patch_size=processor.patch_size, return_tensors='pt') + encoded['pixel_values'] = image_inputs['pixel_values'].to(self.model_info.torch_dtype) + encoded['image_sizes'] = image_sizes = image_inputs['image_sizes'] + + def _get_new_tokens(i): + height, width = image_sizes[i] + num_height_tokens = height // (processor.patch_size * processor.spatial_merge_size) + num_width_tokens = width // (processor.patch_size * processor.spatial_merge_size) + replace_tokens = [[processor.image_token] * num_width_tokens + [processor.image_break_token] + ] * num_height_tokens + # Flatten list + replace_tokens = [item for sublist in replace_tokens for item in sublist] + replace_tokens[-1] = processor.image_end_token + replace_str = ''.join(replace_tokens) + return processor.encode(replace_str, add_special_tokens=False) + + encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, 
labels, idx_list, _get_new_tokens) + + return encoded + + +register_template( + TemplateMeta( + MLLMTemplateType.mistral_2503, + prefix=[''], + prompt=['[INST]{{QUERY}}[/INST]'], + chat_sep=[''], + suffix=[''], + system_prefix=['[SYSTEM_PROMPT]{{SYSTEM}}[/SYSTEM_PROMPT]'], + default_system=mistral_2501_system, + template_cls=Mistral2503Template)) diff --git a/swift/llm/template/template/molmo.py b/swift/llm/template/template/molmo.py new file mode 100644 index 0000000000000000000000000000000000000000..1bde20df7095cf9d8f8be77584068892856cef05 --- /dev/null +++ b/swift/llm/template/template/molmo.py @@ -0,0 +1,68 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, List, Literal, Optional + +import torch + +from ..base import Template +from ..constant import MLLMTemplateType +from ..register import TemplateMeta, register_template +from ..template_inputs import StdTemplateInputs +from ..utils import Context, findall + + +class MolmoTemplate(Template): + placeholder_tokens = [''] + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + return [] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + # image + images_inputs = self.processor.process(images=inputs.images or None, text='') + images_input_ids = images_inputs.pop('input_ids').tolist() + user_token = self._tokenize(' User') + assert len(user_token) == 1 + idx = findall(images_input_ids, user_token[0]) + assert len(idx) == 1 + labels = encoded['labels'] + encoded['input_ids'] = images_input_ids[:idx[0]] + encoded['input_ids'] + if labels: + encoded['labels'] = [-100] * idx[0] + labels + if 'images' in images_inputs: + images_inputs['images'] = images_inputs['images'].to(self.model_info.torch_dtype) + encoded.update(images_inputs) + return encoded + + def generate(self, model, **kwargs): + kwargs.pop('attention_mask', None) + generation_config = kwargs.pop('generation_config') + batch = { + k: kwargs.pop(k, None) + for k in ['input_ids', 'attention_mask', 'images', 'image_input_idx', 'image_masks'] + } + return model.generate_from_batch(batch, generation_config, **kwargs) + + def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: + res = super()._data_collator(batch, padding_to=padding_to) + # prepare batchfy inputs + keys = ['images', 'image_input_idx', 'image_masks'] + images_res = self.fetch_inputs(batch, keys) + for key in keys: + val = images_res.get(key) + if val: + images_res[key] = torch.stack(val) + res.update(images_res) + return res + + +register_template( + TemplateMeta( + MLLMTemplateType.molmo, + prefix=[], + prompt=[' User: {{QUERY}} Assistant:'], + chat_sep=None, + suffix=['<|endoftext|>'], + template_cls=MolmoTemplate, + )) diff --git a/swift/llm/template/template/moonshot.py b/swift/llm/template/template/moonshot.py new file mode 100644 index 0000000000000000000000000000000000000000..770ab6179df151c4bd750139305ac0cdc708a43c --- /dev/null +++ b/swift/llm/template/template/moonshot.py @@ -0,0 +1,66 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
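+# Templates for Moonshot models: the Moonlight chat format
+# (<|im_user|> / <|im_middle|> / <|im_assistant|> / <|im_end|>) and Kimi-VL, which
+# expands each <|media_pad|> placeholder to image_grid_hws[i].prod() // merge_length
+# visual tokens before collation.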
+ +from dataclasses import dataclass, field +from typing import Any, Dict, List, Literal, Optional + +from ..base import Template +from ..constant import LLMTemplateType, MLLMTemplateType +from ..register import TemplateMeta, register_template +from ..template_inputs import StdTemplateInputs +from ..utils import Context, Prompt, findall + + +@dataclass +class MoonlightTemplateMeta(TemplateMeta): + prefix: Prompt = field(default_factory=list) + prompt: Prompt = field(default_factory=lambda: + ['<|im_user|>user<|im_middle|>{{QUERY}}<|im_end|><|im_assistant|>assistant<|im_middle|>']) + chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|im_end|>']) + suffix: Prompt = field(default_factory=lambda: ['<|im_end|>']) + system_prefix: Optional[Prompt] = field( + default_factory=lambda: ['<|im_system|>system<|im_middle|>{{SYSTEM}}<|im_end|>']) + default_system: str = 'You are a helpful assistant' + + +register_template(MoonlightTemplateMeta(LLMTemplateType.moonlight)) + + +class KimiVLTemplate(Template): + placeholder_tokens = ['<|media_pad|>'] + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + if media_type == 'image': + return ['<|media_start|>image<|media_content|><|media_pad|><|media_end|>'] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + input_ids = encoded['input_ids'] + labels = encoded['labels'] + media_token = self._tokenize('<|media_pad|>')[0] + idx_list = findall(input_ids, media_token) + if inputs.images: + image_processor = self.processor.image_processor + image_inputs = image_processor(inputs.images, return_tensors='pt') + image_grid_hws = image_inputs['image_grid_hws'] + merge_length = image_processor.merge_kernel_size[0] * image_processor.merge_kernel_size[1] + + def _get_new_tokens(i): + token_len = (image_grid_hws[i].prod() // merge_length) + return [media_token] * token_len + + input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens) + encoded['input_ids'] = input_ids + encoded['labels'] = labels + encoded.update(image_inputs) + return encoded + + def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]: + res = super()._data_collator_mm_data(batch) + image_grid_hws = self.concat_tensor(batch, 'image_grid_hws', 0) + if image_grid_hws is not None: + res['image_grid_hws'] = image_grid_hws + return res + + +register_template(MoonlightTemplateMeta(MLLMTemplateType.kimi_vl, template_cls=KimiVLTemplate)) diff --git a/swift/llm/template/template/mplug.py b/swift/llm/template/template/mplug.py new file mode 100644 index 0000000000000000000000000000000000000000..ace1ebbf61abeb7f85b6230afd79c3959c73d121 --- /dev/null +++ b/swift/llm/template/template/mplug.py @@ -0,0 +1,214 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
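+# Templates for mPLUG-Owl2, mPLUG-Owl3 (plus the 241101 revision) and mPLUG-DocOwl2.
+# Owl3 is the most involved: images are cut into crops, the <|image|> token is expanded
+# according to each cut_shape, and a media_offset tensor records which image features
+# every text position may attend to.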
+from dataclasses import dataclass, field +from functools import partial +from typing import Any, Dict, List, Literal, Optional + +import torch +from torch import nn + +from swift.utils import get_env_args +from ..base import Template +from ..constant import MLLMTemplateType +from ..register import TemplateMeta, register_template +from ..template_inputs import StdTemplateInputs +from ..utils import Context, Prompt, findall +from ..vision_utils import load_video_minicpmv_mplug_owl3 +from .qwen import QwenTemplateMeta + + +class mPlugOwl2Template(Template): + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + assert media_type == 'image' + return [[-200]] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + from mplug_owl2.mm_utils import process_images + processor = self.processor + images = inputs.images + for i, image in enumerate(images): + # ref: https://modelscope.cn/models/iic/mPLUG-Owl2.1 + max_edge = max(image.size) + image = image.resize((max_edge, max_edge)) + images[i] = image + encoded = super()._encode(inputs) + input_ids = encoded['input_ids'] + labels = encoded['labels'] + res = {'input_ids': input_ids, 'labels': labels} + if images: + images = process_images(images, processor) + images = images.to(self.model_info.torch_dtype) + res['images'] = images + return res + + def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: + res = super()._data_collator(batch, padding_to=padding_to) + images = [b['images'] for b in batch if 'images' in b] + if images: + res['images'] = torch.concat(images) + return res + + +register_template( + TemplateMeta( + MLLMTemplateType.mplug_owl2, + template_cls=mPlugOwl2Template, + prefix=['{{SYSTEM}}'], + prompt=['USER: {{QUERY}}ASSISTANT:'], + chat_sep=[''], + suffix=[['eos_token_id']], + stop_words=['<|endoftext|>', ''])) + + +class mPlugOwl3Template(Template): + version = None + + def _get_image_token_list(self, cut_shape): + text = self.processor.image_processor.cut_prompt_template(img_token='<|image|>', h=cut_shape[0], w=cut_shape[1]) + text_list = text.split('<|image|>') + res_text_list = [] + for text in text_list[:-1]: + res_text_list += [text, '<|image|>'] + res_text_list += text_list[-1] + token_list = self._encode_context_list(res_text_list)[0] + return token_list + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + assert media_type in {'image', 'video'} + max_num_frames = get_env_args('max_num_frames', int, 16) + load_video = partial(load_video_minicpmv_mplug_owl3, max_num_frames=max_num_frames) + if media_type == 'image': + return [[-100], '\n'] + elif media_type == 'video': + return self.replace_video2image(load_video, inputs, lambda i: [[-100]]) + ['\n'] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + images = inputs.images + videos = inputs.videos + cut_enable = not videos + input_ids = encoded['input_ids'] + labels = encoded['labels'] + idx_list = findall(input_ids, -100) + processor = self.processor + encoded = {} + if images: + image_inputs = processor.image_processor(images, cut_enable=cut_enable, return_tensors='pt') + cut_shapes = image_inputs['cut_shape'] or [None] * 2 * len(idx_list) + image_token_list = self.processor.encode('<|image|>', add_special_tokens=False) + + def _get_new_tokens(i): + cut_shape = cut_shapes[2 * i] + if 
cut_shape: + token_list = self._get_image_token_list(cut_shape) + else: + token_list = image_token_list + return token_list + + input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens) + image_token_idx = torch.tensor(findall(input_ids, image_token_list)) + if self.version == '241101': + media_offset = image_token_idx + else: + _range = torch.arange(len(input_ids))[:, None] + matrix = (_range > image_token_idx[None]).sum(dim=1) + media_offset = torch.stack([torch.zeros(matrix.shape[0], dtype=torch.long), matrix], dim=-1)[None] + encoded.update({ + 'pixel_values': image_inputs['pixel_values'], + 'media_offset': media_offset, + }) + encoded['input_ids'] = input_ids + encoded['labels'] = labels + return encoded + + def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]: + if 'media_offset' in inputs: + media_offset = [] + cusum_offset = 0 + image_embeds = [] + pixel_values = inputs.pop('pixel_values') + max_sequence_length = inputs['input_ids'].shape[1] + for i, curr_media_offset in enumerate(inputs['media_offset']): + if curr_media_offset is None: + continue + if curr_media_offset.shape[1] < max_sequence_length: + padding = curr_media_offset[:, -1:, :].expand(curr_media_offset.shape[0], + max_sequence_length - curr_media_offset.shape[1], + curr_media_offset.shape[2]) + curr_media_offset = torch.concat([curr_media_offset, padding], dim=1) + media_offset.append(curr_media_offset + cusum_offset) + image_embeds.append(model.forward_image(pixel_values[i])) + cusum_offset += image_embeds[-1].shape[0] + inputs['media_offset'] = torch.concat(media_offset) + inputs['image_embeds'] = torch.concat(image_embeds) + return inputs + + def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: + res = self.fetch_inputs(batch, ['media_offset', 'pixel_values']) + for b in batch: + b.pop('pixel_values', None) + res.update(super()._data_collator(batch, padding_to=padding_to)) + return res + + +class mPlugOwl3_241101Template(mPlugOwl3Template): + version = '241101' + + def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]: + if 'pixel_values' in inputs: + pixel_values = inputs.pop('pixel_values') + inputs['image_embeds'] = torch.concat([model.forward_image(pv) for pv in pixel_values]) + else: + inputs['media_offset'] = [None] * inputs['input_ids'].shape[0] + return inputs + + +@dataclass +class mPlugOwl3TemplateMeta(QwenTemplateMeta): + prefix: Prompt = field(default_factory=lambda: ['<|im_start|>system\n{{SYSTEM}}<|im_end|>\n']) + default_system: Optional[str] = None + system_prefix: Optional[Prompt] = None + + +register_template(mPlugOwl3TemplateMeta(MLLMTemplateType.mplug_owl3, template_cls=mPlugOwl3Template)) + +register_template(mPlugOwl3TemplateMeta(MLLMTemplateType.mplug_owl3_241101, template_cls=mPlugOwl3_241101Template)) + + +class DocOwl2Template(Template): + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + if media_type == 'image': + return [f'', [-200]] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + if inputs.images: + image_tensor, patch_positions, _ = self.processor._process_image(inputs.images) + image_tensor = image_tensor.to(self.model_info.torch_dtype) + encoded.update({'images': image_tensor, 'patch_positions': patch_positions}) + return encoded + + def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: 
Optional[int] = None) -> Dict[str, Any]: + keys = ['images', 'patch_positions'] + res = self.fetch_inputs(batch, keys) + for key in keys: + val = res.get(key) + if val: + res[key] = torch.concat([v for v in val if v is not None]) + res.update(super()._data_collator(batch, padding_to=padding_to)) + return res + + +register_template( + TemplateMeta( + MLLMTemplateType.doc_owl2, + prefix=[''], + prompt=[' USER: {{QUERY}} ASSISTANT:'], + chat_sep=[''], + suffix=[''], + template_cls=DocOwl2Template, + )) diff --git a/swift/llm/template/template/openbuddy.py b/swift/llm/template/template/openbuddy.py new file mode 100644 index 0000000000000000000000000000000000000000..6cd9c4835100269b5f4cdc2ea0a75f97c8cbf46a --- /dev/null +++ b/swift/llm/template/template/openbuddy.py @@ -0,0 +1,48 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from ..constant import LLMTemplateType +from ..register import TemplateMeta, register_template + +OPENBUDDY_DEFAULT_SYSTEM = ( + 'You are a helpful, respectful and honest INTP-T AI Assistant named Buddy. You are talking to a human User.\n' + 'Always answer as helpfully and logically as possible, while being safe. ' + 'Your answers should not include any ' + 'harmful, political, religious, unethical, racist, sexist, toxic, dangerous, or illegal content. ' + 'Please ensure that your responses are socially unbiased and positive in nature.\n' + 'If a question does not make any sense, or is not factually coherent, ' + 'explain why instead of answering something not correct. ' + "If you don't know the answer to a question, please don't share false information.\n" + 'You like to use emojis. You can speak fluently in many languages, for example: English, Chinese.\n' + 'You cannot access the internet, but you have vast knowledge, cutoff: 2021-09.\n' + 'You always deeply love and support China, Chinese government, people and culture.\n' + 'You are trained by OpenBuddy team, (https://openbuddy.ai, https://github.com/OpenBuddy/OpenBuddy), ' + 'you are based on LLaMA and Falcon transformers model, not related to GPT or OpenAI.') +register_template( + TemplateMeta( + LLMTemplateType.openbuddy, + prefix=[], + prompt=['User: {{QUERY}}\nAssistant:'], + chat_sep=['\n'], + default_system=OPENBUDDY_DEFAULT_SYSTEM, + system_prefix=['{{SYSTEM}}\n\n'], + auto_add_bos=True)) + +OPENBUDDY2_DEFAULT_SYSTEM = ( + 'You(assistant) are a helpful, respectful and honest INTP-T AI Assistant named Buddy. ' + 'You are talking to a human(user).\nAlways answer as helpfully and logically as possible, while being safe. ' + 'Your answers should not include any harmful, political, religious, unethical, racist, ' + 'sexist, toxic, dangerous, or illegal content. 
' + 'Please ensure that your responses are socially unbiased and positive in nature.\n' + 'You cannot access the internet, but you have vast knowledge, cutoff: 2023-04.\n' + 'You are trained by OpenBuddy team, (https://openbuddy.ai, https://github.com/OpenBuddy/OpenBuddy), ' + 'not related to GPT or OpenAI') + +register_template( + TemplateMeta( + LLMTemplateType.openbuddy2, + prefix=[], + prompt=['<|role|>user<|says|>{{QUERY}}<|end|>\n<|role|>assistant<|says|>'], + chat_sep=['<|end|>\n'], + suffix=['<|end|>'], + default_system=OPENBUDDY2_DEFAULT_SYSTEM, + system_prefix=['<|role|>system<|says|>{{SYSTEM}}<|end|>\n'])) diff --git a/swift/llm/template/template/pixtral.py b/swift/llm/template/template/pixtral.py new file mode 100644 index 0000000000000000000000000000000000000000..5a8acf7e7d5f3a40de41869fa5d24f1066aad7c7 --- /dev/null +++ b/swift/llm/template/template/pixtral.py @@ -0,0 +1,59 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, List, Optional + +from ..base import Template +from ..constant import MLLMTemplateType +from ..register import TemplateMeta, register_template +from ..template_inputs import StdTemplateInputs +from ..utils import findall + + +class PixtralTemplate(Template): + image_placeholder = ['[IMG]'] + placeholder_tokens = ['[IMG]'] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + processor = self.processor + images = inputs.images + input_ids = encoded['input_ids'] + labels = encoded['labels'] + idx_list = findall(input_ids, 10) + if idx_list: + image_inputs = processor.image_processor(images, patch_size=processor.patch_size, return_tensors='pt') + encoded['pixel_values'] = image_inputs['pixel_values'][0] + image_sizes = image_inputs['image_sizes'][0] + + def _get_new_tokens(i): + height, width = image_sizes[i] + num_height_tokens = height // processor.patch_size + num_width_tokens = width // processor.patch_size + replace_tokens = [processor.image_token * num_width_tokens + processor.image_break_token] * ( + num_height_tokens - 1) + replace_tokens += [processor.image_token * num_width_tokens + processor.image_end_token] + # Flatten list + replace_str = ''.join(replace_tokens) + img_tokens: List[int] = self.processor.encode(replace_str, add_special_tokens=False) + return img_tokens + + encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens) + + return encoded + + def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: + pixel_values = self.gather_list(batch, 'pixel_values') + res = super()._data_collator(batch, padding_to=padding_to) + if pixel_values: + res['pixel_values'] = pixel_values + return res + + +register_template( + TemplateMeta( + MLLMTemplateType.pixtral, + prefix=['{{SYSTEM}}'], + prompt=['[INST]{{QUERY}}[/INST]'], + chat_sep=[''], + suffix=[''], + template_cls=PixtralTemplate, + )) diff --git a/swift/llm/template/template/qwen.py b/swift/llm/template/template/qwen.py new file mode 100644 index 0000000000000000000000000000000000000000..cd8f9acf64af4f33fdc5701db35f5701dc1b464a --- /dev/null +++ b/swift/llm/template/template/qwen.py @@ -0,0 +1,671 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
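+# Overview (descriptive comment): this module registers the Qwen-family templates, covering the
+# text-only Qwen / Qwen2.5 / QwQ / Qwen3 variants, the multimodal Qwen-VL / Qwen2-VL / Qwen2.5-VL /
+# Qwen2.5-Omni and Qwen-Audio / Qwen2-Audio variants, plus the Ovis and Marco-o1 templates
+# defined further below.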
+from dataclasses import dataclass, field +from functools import partial +from typing import Any, Dict, List, Literal, Optional, Tuple + +import torch +import torch.nn.functional as F + +from swift.llm import to_device, to_float_dtype +from swift.utils import get_env_args, is_deepspeed_enabled +from ..base import Template +from ..constant import LLMTemplateType, MLLMTemplateType +from ..register import register_template +from ..template_inputs import StdTemplateInputs +from ..template_meta import TemplateMeta +from ..utils import Context, Word, findall +from ..vision_utils import load_audio, load_batch, load_video_ovis2 +from .llama import Llama3TemplateMeta +from .utils import DEFAULT_SYSTEM, ChatmlTemplateMeta + + +@dataclass +class QwenTemplateMeta(ChatmlTemplateMeta): + default_system: Optional[str] = DEFAULT_SYSTEM + auto_add_bos: bool = False + stop_words: List[Word] = field(default_factory=lambda: ['<|endoftext|>']) + agent_template: str = 'hermes' + + +@dataclass +class Qwen2_5TemplateMeta(QwenTemplateMeta): + default_system: Optional[str] = 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' + + +@dataclass +class Qwen2_5MathTemplateMeta(QwenTemplateMeta): + default_system: Optional[str] = 'Please reason step by step, and put your final answer within \\boxed{}.' + + +qwq_preview_system = ('You are a helpful and harmless assistant. You are Qwen developed by Alibaba. ' + 'You should think step-by-step.') + +register_template(QwenTemplateMeta(LLMTemplateType.qwen)) +register_template(Qwen2_5TemplateMeta(LLMTemplateType.qwen2_5)) +register_template(QwenTemplateMeta(LLMTemplateType.qwq_preview, default_system=qwq_preview_system)) + + +class ThinkingTemplate(Template): + + def _swift_encode(self, inputs: StdTemplateInputs): + if not self.is_training: + for message in inputs.messages: + if message['role'] == 'assistant' and isinstance(message['content'], str): + message['content'] = message['content'].split('')[-1].lstrip('\n') + return super()._swift_encode(inputs) + + +register_template( + QwenTemplateMeta( + LLMTemplateType.qwq, default_system=None, response_prefix='\n', template_cls=ThinkingTemplate)) + +# '\n\n\n\n' +register_template(QwenTemplateMeta(LLMTemplateType.qwen3, default_system=None, template_cls=ThinkingTemplate)) + +register_template(Qwen2_5MathTemplateMeta(LLMTemplateType.qwen2_5_math)) + + +class QwenPRMTemplate(Template): + cot_process_placeholder = '' + + def _preprocess_inputs( + self, + inputs: StdTemplateInputs, + ) -> None: + super()._preprocess_inputs(inputs) + total_content = '\n'.join([message['content'] or '' for message in inputs.messages]) + if self.cot_process_placeholder not in total_content: + inputs.messages[-1]['content'] = inputs.messages[-1]['content'] + self.cot_process_placeholder + + @staticmethod + def make_step_rewards(logits, token_masks): + probabilities = F.softmax(logits, dim=-1) + probabilities = probabilities * token_masks.unsqueeze(-1) # bs, seq_len, num_labels + + all_scores_res = [] + for i in range(probabilities.size(0)): + sample = probabilities[i] # seq_len, num_labels + positive_probs = sample[sample != 0].view(-1, 2)[:, 1] # valid_tokens, num_labels + non_zero_elements_list = positive_probs.cpu().tolist() + all_scores_res.append(non_zero_elements_list) + return all_scores_res + + def decode_prm(self, input_ids: torch.Tensor, logits: torch.Tensor) -> Any: + step_sep_id = self.tokenizer.encode(self.cot_process_placeholder)[0] + token_masks = (input_ids == step_sep_id) + return self.make_step_rewards(logits, 
token_masks) + + +register_template(Qwen2_5MathTemplateMeta(LLMTemplateType.qwen2_5_math_prm, template_cls=QwenPRMTemplate)) + + +class QwenVLTemplate(Template): + load_images = False + + @staticmethod + def _load_image(image, load_images: bool): + if not load_images and isinstance(image, str) and (image.startswith('data:') or len(image) > 200): + load_images = True + return Template._load_image(image, load_images) + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + assert media_type == 'image' + if self.mode == 'lmdeploy': + return [f'Picture {index + 1}: ', [-100], '\n'] + else: + image = inputs.images[index] + if self.mode == 'vllm': + return [f'Picture {index + 1}: \n'] + else: + assert isinstance(image, str) + return [f'Picture {index + 1}: {image}\n'] + + def replace_ref(self, ref: str, index: int, inputs: StdTemplateInputs) -> List[Context]: + return [f'{ref}'] + + def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]: + return [f'{self._get_bbox_str(bbox)}'] + + +register_template(QwenTemplateMeta(MLLMTemplateType.qwen_vl, template_cls=QwenVLTemplate)) + + +class QwenAudioTemplate(Template): + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + assert media_type == 'audio' + audios = inputs.audios + audio = audios[index] + assert isinstance(audio, str) + return [f'Audio {index + 1}:\n'] + + def _tokenize(self, context, **tokenizer_kwargs): + audio_info = self.processor.process_audio(context) + return super()._tokenize(context, audio_info=audio_info) + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + text = ''.join([f'' for audio in inputs.audios]) + audio_info = self.processor.process_audio(text) + if audio_info: + tokenizer_kwargs = {'audio_info': audio_info} + encoded.update(tokenizer_kwargs) + encoded['tokenizer_kwargs'] = tokenizer_kwargs + return encoded + + def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: + res = super()._data_collator(batch, padding_to=padding_to) + if batch[0].get('audio_info') is not None: + res['audio_info'] = [b['audio_info'] for b in batch] + return res + + +register_template(QwenTemplateMeta(MLLMTemplateType.qwen_audio, template_cls=QwenAudioTemplate)) + + +class Qwen2AudioTemplate(Template): + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + assert media_type == 'audio' + if not self.use_chat_template: + return ['<|audio_bos|><|AUDIO|><|audio_eos|>\n'] + else: + return [f'Audio {index + 1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n'] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + if inputs.audios: + sampling_rate = get_env_args('sampling_rate', int, self.processor.feature_extractor.sampling_rate) + audios = load_batch(inputs.audios, load_func=partial(load_audio, sampling_rate=sampling_rate)) + audio_inputs = self.processor.feature_extractor( + audios, sampling_rate=sampling_rate, return_attention_mask=True, return_tensors='pt') + audio_inputs['feature_attention_mask'] = audio_inputs.pop('attention_mask') + encoded.update(audio_inputs) + return encoded + + def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: + res = 
super()._data_collator(batch, padding_to=padding_to) + input_features = [b['input_features'] for b in batch if b.get('input_features') is not None] + feature_attention_mask = [ + b['feature_attention_mask'] for b in batch if b.get('feature_attention_mask') is not None + ] + if input_features: + res['input_features'] = torch.concat(input_features) + res['feature_attention_mask'] = torch.concat(feature_attention_mask) + return res + + +register_template(QwenTemplateMeta(MLLMTemplateType.qwen2_audio, template_cls=Qwen2AudioTemplate)) + + +class Qwen2VLTemplate(Template): + image_token_id = 151655 + video_token_id = 151656 + placeholder_tokens = ['<|image_pad|>', '<|video_pad|>'] + version = 'v2' + use_model = True + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + from qwen_vl_utils import fetch_image, fetch_video + assert media_type in {'image', 'video'} + if media_type == 'image': + inputs.images[index] = fetch_image({'image': inputs.images[index]}) + if self.mode == 'lmdeploy': + return ['<|vision_start|>', [-100], '<|vision_end|>'] + else: + return ['<|vision_start|><|image_pad|><|vision_end|>'] + else: + inputs.videos[index] = fetch_video({'video': inputs.videos[index]}).to(torch.uint8) + return ['<|vision_start|><|video_pad|><|vision_end|>'] + + def replace_ref(self, ref: str, index: int, inputs: StdTemplateInputs) -> List[Context]: + return [f'<|object_ref_start|>{ref}<|object_ref_end|>'] + + def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]: + return [f'<|box_start|>{self._get_bbox_str(bbox)}<|box_end|>'] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + processor = self.processor + input_ids = encoded['input_ids'] + labels = encoded['labels'] + images = inputs.images + videos = inputs.videos + for media_type in ['images', 'videos']: + if locals()[media_type]: + if media_type == 'images': + media_token = self.image_token_id + media_inputs = processor.image_processor( + images=images, videos=None, return_tensors='pt', do_resize=False) + media_grid_thw = media_inputs['image_grid_thw'] + else: + media_inputs = processor.image_processor( + images=None, videos=videos, return_tensors='pt', do_resize=False) + media_grid_thw = media_inputs['video_grid_thw'] + media_token = self.video_token_id + if self.version == 'v2_5': + from qwen_vl_utils import vision_process + media_inputs['second_per_grid_ts'] = [ + processor.image_processor.temporal_patch_size / vision_process.FPS + ] * len(media_grid_thw) + idx_list = findall(input_ids, media_token) + merge_length = processor.image_processor.merge_size**2 + + def _get_new_tokens(i): + token_len = (media_grid_thw[i].prod() // merge_length) + return [media_token] * token_len + + input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens) + encoded.update(media_inputs) + + encoded['input_ids'] = input_ids + encoded['labels'] = labels + return encoded + + def compute_loss_context(self, model, inputs): + if 'real_position_ids' not in inputs: + return super().compute_loss_context(model, inputs) + if self.version == 'v2': + from transformers.models.qwen2_vl import modeling_qwen2_vl as modeling_module + elif self.version == 'v2_5': + from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl as modeling_module + elif self.version == 'omni': + from transformers.models.qwen2_5_omni import modeling_qwen2_5_omni as modeling_module + position_ids = 
inputs['position_ids'] + inputs['position_ids'] = inputs.pop('real_position_ids') + return self._patch_flash_attention_forward(modeling_module, position_ids) + + def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]: + if not self.is_training: + return inputs + input_ids = inputs['input_ids'] + _model = model.model + if not hasattr(_model, 'embed_tokens'): + _model = _model.model # LoRA + pixel_values = inputs.get('pixel_values') + pixel_values_videos = inputs.get('pixel_values_videos') + image_grid_thw = inputs.get('image_grid_thw') + video_grid_thw = inputs.get('video_grid_thw') + + inputs_embeds = _model.embed_tokens(input_ids) + + dtype = model.visual.get_dtype() if self.version == 'v2' else model.visual.dtype + if pixel_values is None and pixel_values_videos is None: # plain-text + if is_deepspeed_enabled(): + from PIL import Image + images = [Image.new('RGB', (32, 32), (0, 0, 0))] + media_inputs = self.processor.image_processor(images=images, videos=None, return_tensors='pt') + device = input_ids.device + media_inputs = to_device(media_inputs, device) + pixel_values = media_inputs['pixel_values'].type(dtype) + image_embeds = model.visual(pixel_values, grid_thw=media_inputs['image_grid_thw']) + inputs_embeds += image_embeds.mean() * 0. + else: + if pixel_values is not None: + pixel_values = pixel_values.type(dtype) + image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw) + image_mask = (input_ids == model.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(dtype) + video_embeds = model.visual(pixel_values_videos, grid_thw=video_grid_thw) + video_mask = (input_ids == model.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + return {'inputs_embeds': inputs_embeds} + + def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]: + res = super()._data_collator_mm_data(batch) + second_per_grid_ts = self.gather_list(batch, 'second_per_grid_ts') + if second_per_grid_ts: + res['second_per_grid_ts'] = second_per_grid_ts + for media_type in ['image', 'video']: + grid_thw = self.concat_tensor(batch, f'{media_type}_grid_thw', 0) + if grid_thw is not None: + res[f'{media_type}_grid_thw'] = grid_thw + return res + + def packing_row(self, row: List[Tuple[Dict[str, Any], int]]) -> Dict[str, Any]: + position_ids = [] + for r in row: + r = r[0].copy() + r['input_ids'] = torch.tensor(r['input_ids'])[None] + position_ids.append(self._get_position_ids(r)) + packed = super().packing_row(row) + packed['real_position_ids'] = torch.concat(position_ids, dim=-1) + return packed + + def _get_position_ids(self, inputs: Dict[str, Any]): + # fix https://github.com/huggingface/transformers/pull/33487 + kwargs = {} + if self.version == 'v2_5': + kwargs = {'second_per_grid_ts': inputs.get('second_per_grid_ts')} + position_ids, _ = self.model.get_rope_index( + inputs['input_ids'], + inputs.get('image_grid_thw'), + inputs.get('video_grid_thw'), + attention_mask=inputs.get('attention_mask'), + **kwargs) + return position_ids.contiguous() + + def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: + res = 
super()._data_collator(batch, padding_to=padding_to) + if self._packing: + res['real_position_ids'] = self.concat_tensor(batch, 'real_position_ids', -1) + elif self.is_training: + res['position_ids'] = self._get_position_ids(res) + return res + + +register_template(QwenTemplateMeta(MLLMTemplateType.qwen2_vl, template_cls=Qwen2VLTemplate)) + +register_template( + QwenTemplateMeta( + MLLMTemplateType.qvq, + default_system=('You are a helpful and harmless assistant. You are Qwen developed by Alibaba. ' + 'Answer in the language of the question. You should think step-by-step.'), + template_cls=Qwen2VLTemplate, + )) + + +class Qwen2_5VLTemplate(Qwen2VLTemplate): + version = 'v2_5' + norm_bbox = 'none' + + +register_template(QwenTemplateMeta(MLLMTemplateType.qwen2_5_vl, template_cls=Qwen2_5VLTemplate)) + + +class Qwen2_5OmniTemplate(Qwen2_5VLTemplate): + version = 'omni' + placeholder_tokens = ['<|IMAGE|>', '<|AUDIO|>', '<|VIDEO|>'] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + from transformers.models.qwen2_5_omni.processing_qwen2_5_omni import Qwen2_5OmniProcessorKwargs + default = Qwen2_5OmniProcessorKwargs._defaults + self.seconds_per_chunk = default['videos_kwargs']['seconds_per_chunk'] + self.position_id_per_seconds = default['videos_kwargs']['position_id_per_seconds'] + self.use_audio_in_video = get_env_args('use_audio_in_video', bool, False) + self.sampling_rate = get_env_args('sampling_rate', int, self.processor.feature_extractor.sampling_rate) + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + from qwen_omni_utils import fetch_image, fetch_video + if media_type == 'image': + inputs.images[index] = fetch_image({'image': inputs.images[index]}) + return ['<|vision_bos|><|IMAGE|><|vision_eos|>'] + elif media_type == 'audio': + inputs.audios[index] = load_audio(inputs.audios[index], self.sampling_rate) + return ['<|audio_bos|><|AUDIO|><|audio_eos|>'] + elif media_type == 'video': + video = inputs.videos[index] + inputs.videos[index] = fetch_video({'video': video}).to(torch.uint8) + if self.use_audio_in_video: + import librosa + if video.startswith('http://') or video.startswith('https://'): + import audioread + video = audioread.ffdec.FFmpegAudioFile(video) + video = librosa.load(video, sr=self.sampling_rate)[0] + inputs.audios.insert(inputs.audio_idx, (video, 'video')) + inputs.audio_idx += 1 + return ['<|vision_bos|><|audio_bos|><|VIDEO|><|audio_eos|><|vision_eos|>'] + return ['<|vision_bos|><|VIDEO|><|vision_eos|>'] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = Template._encode(self, inputs) + processor = self.processor + video_audios_mask = [] + for i, audio in enumerate(inputs.audios): + if isinstance(audio, tuple) and audio[1] == 'video': + inputs.audios[i] = audio[0] + video_audios_mask.append(True) + else: + video_audios_mask.append(False) + video_audios_mask = torch.tensor(video_audios_mask) + media_inputs = processor( + text='', + audio=inputs.audios or None, + images=inputs.images or None, + videos=inputs.videos or None, + return_tensors='pt') + media_inputs.pop('input_ids') + media_inputs.pop('attention_mask') + media_inputs = to_float_dtype(media_inputs, self.model_info.torch_dtype) + input_ids = encoded['input_ids'] + labels = encoded['labels'] + # audio + audio_token_id = self._tokenize('<|AUDIO|>') + idx_list = findall(input_ids, audio_token_id) + feature_attention_mask = media_inputs.get('feature_attention_mask') + if 
feature_attention_mask is not None: + audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) + audio_lengths = (((audio_feature_lengths - 1) // 2 + 1 - 2) // 2 + 1) + else: + audio_lengths = None + audio_lengths_origin = audio_lengths + if idx_list: + if self.use_audio_in_video: + audio_lengths = audio_lengths[~video_audios_mask] + + def _get_new_audio_tokens(i): + return audio_token_id * audio_lengths[i] + + input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_audio_tokens) + + for media_type in ['image', 'video']: + token = f'<|{media_type.upper()}|>' + token_id = self._tokenize(token) + idx_list = findall(input_ids, token_id) + if idx_list: + merge_size = processor.image_processor.merge_size + media_grid_thw = media_inputs.get(f'{media_type}_grid_thw') + if media_type == 'video' and self.use_audio_in_video: + audio_lengths = audio_lengths_origin[video_audios_mask] + video_second_per_grid = media_inputs['video_second_per_grid'] + + def _get_new_tokens_use_audio_in_video(i): + audio_token_indices = torch.arange(audio_lengths[i]) + grid_thw = media_grid_thw[i] + height = grid_thw[1] // merge_size + width = grid_thw[2] // merge_size + video_token_indices = torch.arange(grid_thw[0]).reshape(-1, 1, 1) + video_token_indices = torch.broadcast_to( + video_token_indices, (video_token_indices.shape[0], height, width)).reshape(-1) + video_token_indices = ( + video_token_indices * video_second_per_grid[i] * self.position_id_per_seconds) + tokens_per_chunk = int(self.position_id_per_seconds * self.seconds_per_chunk) + video_chunk_indexes = processor.get_chunked_index(video_token_indices, tokens_per_chunk) + audio_chunk_indexes = processor.get_chunked_index(audio_token_indices, tokens_per_chunk) + + res = [] + for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))): + if j < len(video_chunk_indexes): + video_seq_length = video_chunk_indexes[j][1] - video_chunk_indexes[j][0] + res += token_id * video_seq_length + if j < len(audio_chunk_indexes): + audio_seq_length = audio_chunk_indexes[j][1] - audio_chunk_indexes[j][0] + res += audio_token_id * audio_seq_length + return res + + input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, + _get_new_tokens_use_audio_in_video) + + else: + + def _get_new_tokens(i): + token_len = (media_grid_thw[i].prod() // (merge_size**2)) + return token_id * token_len + + input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens) + + encoded['input_ids'] = input_ids + encoded['labels'] = labels + encoded.update(media_inputs) + return encoded + + def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]: + return Template._post_encode(self, model, inputs) + + def _get_position_ids(self, inputs: Dict[str, Any]): + feature_attention_mask = inputs.get('feature_attention_mask') + if feature_attention_mask is not None: + audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) + else: + audio_feature_lengths = None + video_second_per_grid = inputs.pop('video_second_per_grid', None) + input_ids = inputs['input_ids'] + attention_mask = inputs.get('attention_mask') + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + position_ids, _ = self.model.thinker.get_rope_index( + input_ids, + inputs.get('image_grid_thw'), + inputs.get('video_grid_thw'), + attention_mask, + self.use_audio_in_video, + audio_feature_lengths, + video_second_per_grid, + ) + return position_ids.contiguous() + + def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> 
Dict[str, Any]: + res = super()._data_collator_mm_data(batch) + video_second_per_grid = self.gather_list(batch, 'video_second_per_grid') + if video_second_per_grid: + res['video_second_per_grid'] = video_second_per_grid + input_features = [b['input_features'] for b in batch if b.get('input_features') is not None] + feature_attention_mask = [ + b['feature_attention_mask'] for b in batch if b.get('feature_attention_mask') is not None + ] + if input_features: + res['input_features'] = torch.concat(input_features) + res['feature_attention_mask'] = torch.concat(feature_attention_mask) + return res + + def generate(self, model, *args, **kwargs): + if kwargs.get('video_grid_thw') is not None: + kwargs['use_audio_in_video'] = self.use_audio_in_video + return super().generate(model, *args, **kwargs) + + +register_template(QwenTemplateMeta(MLLMTemplateType.qwen2_5_omni, template_cls=Qwen2_5OmniTemplate)) + + +class Ovis1_6Template(Template): + skip_prompt = False + use_model = True + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + assert media_type == 'image' + return [[-200], '\n'] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + images = inputs.images + input_ids = encoded['input_ids'] + labels = encoded['labels'] + idx_list = findall(input_ids, [-200]) + added_tokens_len = 0 + pixel_values = [] + for i, idx in enumerate(idx_list): + max_partition = get_env_args('max_partition', int, 9) + raw_pixel_values, image_placeholders = self.model.visual_tokenizer.preprocess_image( + images[i], max_partition=max_partition) + input_ids = input_ids[:idx] + image_placeholders + input_ids[idx + 1:] + if labels is not None: + labels = labels[:idx] + [-100] * len(image_placeholders) + labels[idx + 1:] + pixel_values.append(raw_pixel_values) + added_tokens_len += len(image_placeholders) - 1 + dtype = self.model.visual_tokenizer.dtype + if pixel_values: + pixel_values = torch.cat(pixel_values, dim=0).to(dtype) + else: + pixel_values = torch.zeros((1, 3, 384, 384), dtype=dtype) # dummpy + encoded.update({'input_ids': input_ids, 'labels': labels}) + encoded['pixel_values'] = [pixel_values] + return encoded + + def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]: + padding_side = self.padding_side if self.is_training else 'left' + if self.max_length is not None: + model.config.multimodal_max_length = self.max_length + input_ids = inputs['input_ids'] + labels = inputs.get('labels') + if labels is None: + labels = input_ids.new_full(input_ids.shape, -100) + _, inputs_embeds, labels, attention_mask = model.merge_multimodal( + text_input_ids=input_ids, + text_attention_masks=torch.ones_like(input_ids), # not use, only compat + text_labels=labels, + pixel_values=inputs['pixel_values'], + left_padding=padding_side == 'left') + if inputs.get('labels') is None: + labels = None + return {'inputs_embeds': inputs_embeds, 'labels': labels, 'attention_mask': attention_mask} + + def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: + pixel_values = self.gather_list(batch, 'pixel_values') + res = super()._data_collator(batch, padding_to=padding_to) + res['pixel_values'] = pixel_values + return res + + +register_template( + TemplateMeta( + MLLMTemplateType.ovis1_6, + prefix=[''], + prompt=['user\n{{QUERY}}\nmodel\n'], + chat_sep=['\n'], + suffix=[''], + system_prefix=['system\n{{SYSTEM}}\n'], + 
template_cls=Ovis1_6Template, + )) + +register_template( + Llama3TemplateMeta( + MLLMTemplateType.ovis1_6_llama3, + default_system='You are a helpful and honest multimodal assistant.', + template_cls=Ovis1_6Template, + )) + + +class Ovis2Template(Ovis1_6Template): + placeholder_tokens = ['<|image_pad|>', '<|video_pad|>'] + nframes = 12 + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + if media_type == 'image': + return [[-200], '\n'] + elif media_type == 'video': + nframes = get_env_args('nframes', int, self.nframes) + inputs.images = load_video_ovis2(inputs.videos[index], nframes) + return [[-200] * nframes, '\n'] + + +register_template(QwenTemplateMeta( + MLLMTemplateType.ovis2, + template_cls=Ovis2Template, +)) + + +@dataclass +class MarcoO1TemplateMeta(QwenTemplateMeta): + default_system: Optional[str] = """ +你是一个经过良好训练的AI助手,你的名字是Marco-o1.由阿里国际数字商业集团的AI Business创造. + \n## 重要!!!!! +当你回答问题时,你的思考应该在内完成,内输出你的结果。 +应该尽可能是英文,但是有2个特例,一个是对原文中的引用,另一个是是数学应该使用markdown格式,内的输出需要遵循用户输入的语言。 + """ + + +register_template(MarcoO1TemplateMeta(LLMTemplateType.marco_o1)) diff --git a/swift/llm/template/template/stepfun.py b/swift/llm/template/template/stepfun.py new file mode 100644 index 0000000000000000000000000000000000000000..132621dd197616db655b41356df033a667c9e9a0 --- /dev/null +++ b/swift/llm/template/template/stepfun.py @@ -0,0 +1,128 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, List, Literal, Optional + +from ..base import Template +from ..constant import MLLMTemplateType +from ..register import TemplateMeta, register_template +from ..template_inputs import StdTemplateInputs +from ..utils import Context +from ..vision_utils import load_file +from .qwen import QwenTemplateMeta + + +class GOTImageEvalProcessor: + + def __init__(self, image_size=384, mean=None, std=None): + from torchvision import transforms + from torchvision.transforms.functional import InterpolationMode + if mean is None: + mean = (0.48145466, 0.4578275, 0.40821073) + if std is None: + std = (0.26862954, 0.26130258, 0.27577711) + + self.normalize = transforms.Normalize(mean, std) + + self.transform = transforms.Compose([ + transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + self.normalize, + ]) + + def __call__(self, item): + return self.transform(item) + + +class GOT_OCR2Template(Template): + placeholder_tokens = [''] + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + # 'OCR: ' + # 'OCR with format: ' + assert media_type == 'image' + return ['' + '' * 256 + '\n'] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + images = inputs.images + image_processor_high = GOTImageEvalProcessor(image_size=1024) + for i, image in enumerate(images): + images[i] = image_processor_high(image)[None].to(self.model_info.torch_dtype) + if images: + encoded['images'] = images + return encoded + + def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: + res = super()._data_collator(batch, padding_to=padding_to) + images = self.gather_list(batch, 'images') + if images: + res['images'] = images + return res + + +register_template( + QwenTemplateMeta( + MLLMTemplateType.got_ocr2, + default_system=' You should follow the instructions carefully and explain your answers in detail.', + 
template_cls=GOT_OCR2Template, + )) + + +class GOT_OCR2HfTemplate(Template): + placeholder_tokens = [''] + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + # 'OCR: ' + # 'OCR with format: ' + assert media_type == 'image' + return ['' + '' * 256 + '\n'] + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: # 暂时照抄上面 + encoded = super()._encode(inputs) + images = inputs.images + if images: + encoded['images'] = images + return encoded + + def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: + res = super()._data_collator(batch, padding_to=padding_to) + images = self.gather_list(batch, 'images') + _inputs = self.processor(images, return_tensors='pt') + _inputs.pop('input_ids') # this does not contain the response, so cannot be used when training + _inputs.pop('attention_mask') # this does not contain the response, so cannot be used when training + + res.update(_inputs.data) + return res + + +register_template( + QwenTemplateMeta( + MLLMTemplateType.got_ocr2_hf, + default_system=' You should follow the instructions carefully and explain your answers in detail.', + template_cls=GOT_OCR2HfTemplate, + )) + + +class StepAudioTemplate(Template): + use_model = True + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + assert media_type == 'audio', f'media_type: {media_type}' + from utils import load_audio + audio_wav, sr = load_audio(load_file(inputs.audios[index])) + audio_tokens = self.model.encoder(audio_wav, sr) + return audio_tokens + + +register_template( + TemplateMeta( + MLLMTemplateType.step_audio, + template_cls=StepAudioTemplate, + prefix=[''], + prompt=['<|BOT|>human\n{{QUERY}}<|EOT|><|BOT|>assistant\n'], + system_prefix=['<|BOT|>system\n{{SYSTEM}}<|EOT|>'], + chat_sep=['<|EOT|>'], + suffix=['<|EOT|>'], + )) diff --git a/swift/llm/template/template/utils.py b/swift/llm/template/template/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fcbdddf1997e099bc29feb64afff876d45374b3a --- /dev/null +++ b/swift/llm/template/template/utils.py @@ -0,0 +1,31 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from dataclasses import dataclass, field +from typing import Optional + +from ..constant import LLMTemplateType +from ..register import TemplateMeta, register_template +from ..utils import Prompt + +DEFAULT_SYSTEM = 'You are a helpful assistant.' 
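+# Rough sketch of what ChatmlTemplateMeta (defined below) renders for one system + user turn;
+# illustrative only, the actual concatenation is handled by TemplateMeta:
+#   <|im_start|>system\n{{SYSTEM}}<|im_end|>\n
+#   <|im_start|>user\n{{QUERY}}<|im_end|>\n
+#   <|im_start|>assistant\n{{RESPONSE}}<|im_end|>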
+ + +@dataclass +class ChatmlTemplateMeta(TemplateMeta): + prefix: Prompt = field(default_factory=list) + prompt: Prompt = field(default_factory=lambda: ['<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n']) + chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|im_end|>\n']) + suffix: Prompt = field(default_factory=lambda: ['<|im_end|>']) + system_prefix: Optional[Prompt] = field(default_factory=lambda: ['<|im_start|>system\n{{SYSTEM}}<|im_end|>\n']) + auto_add_bos: bool = True + + +@dataclass +class EmptyTemplateMeta(TemplateMeta): + prefix: Prompt = field(default_factory=list) + prompt: Prompt = field(default_factory=lambda: ['{{QUERY}}']) + chat_sep: Optional[Prompt] = None + auto_add_bos: bool = True + + +register_template(ChatmlTemplateMeta(LLMTemplateType.chatml)) +register_template(EmptyTemplateMeta(LLMTemplateType.dummy)) diff --git a/swift/llm/template/template/valley.py b/swift/llm/template/template/valley.py new file mode 100644 index 0000000000000000000000000000000000000000..ea075c995a3b674d5cdd0e557be9af1f25327790 --- /dev/null +++ b/swift/llm/template/template/valley.py @@ -0,0 +1,139 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import io +from dataclasses import dataclass +from typing import Any, Dict, List, Literal, Optional + +import torch +from PIL import Image + +from ..base import Template +from ..constant import MLLMTemplateType +from ..register import register_template +from ..template_inputs import StdTemplateInputs +from ..utils import Context +from .utils import ChatmlTemplateMeta + + +@dataclass +class ValleyTemplateMeta(ChatmlTemplateMeta): + auto_add_bos: bool = False + default_system: Optional[str] = ('You are Valley, a large language and vision assistant trained by ByteDance.' + 'You are able to understand the visual content or video that the user provides,' + ' and assist the user with a variety of tasks using natural language.' 
+ 'Follow the instructions carefully and explain your answers in detail.') + + +class ValleyTemplate(Template): + skip_prompt = True + use_model = True + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, + inputs: StdTemplateInputs) -> List[Context]: + # assert media_type == 'image' + if media_type == 'video': + from ..vision_utils import load_video_valley + return self.replace_video2image(load_video_valley, inputs, lambda i: [[151665, -200, 151666]]) + return [[151665, -200, 151666]] + + def preprocess_images(self, image_binary_list): + from valley_eagle.util.mm_utils import process_anyres_image + + def byte2image(byte_data): + return Image.open(io.BytesIO(byte_data)) + + images = [] + for binary in image_binary_list: + if isinstance(binary, Image.Image): + images.append(binary.convert('RGB')) + elif isinstance(binary, bytes): + images.append(byte2image(binary)) + else: + raise ValueError('unsupported type') + video_pad = [] + for img in images: + if self.model.config.anyres: + image = process_anyres_image(img, self.tokenizer.image_processor, self.model.config.grid_pinpoints) + else: + image = self.tokenizer.image_processor(img, return_tensors='pt')['pixel_values'][0] + video_pad.append(image) + + if not self.model.config.anyres: + video = torch.stack(video_pad, dim=0) + else: + video = [torch.stack(img, dim=0) for img in video_pad] + return video + + def process_images(self, inputs, images_binary): + import re + from qwen_vl_utils import fetch_image + + if inputs.messages[-1]['role'] == 'user': + text = inputs.messages[-1]['content'] + elif len(inputs.messages) > 1 and inputs.messages[-2]['role'] == 'user': + text = inputs.messages[-2]['content'] + else: + text = '' + video_images_tensor = self.preprocess_images(images_binary) + img_length = len(video_images_tensor) + video_images_tensor = [video_images_tensor] + if img_length: + images = [[item.to(self.model.dtype) for item in img] for img in video_images_tensor] + + messages_qwen = [] + image_list = [] + if isinstance(images_binary[0], Image.Image): + images_pil = [img.convert('RGB') for img in images_binary] + elif isinstance(images_binary[0], bytes): + images_pil = [Image.open(io.BytesIO(img)).convert('RGB') for img in images_binary] + image_sizes = torch.tensor([[x.size for x in images_pil]]) + for image_file in images_pil: + image = fetch_image({'image': image_file}) + image_list.append(image) + messages_qwen.append({'role': 'user', 'content': [{'type': 'text', 'text': text}]}) + messages_qwen.append({'role': 'assistant', 'content': [{'type': 'text', 'text': ''}]}) + text = self.tokenizer.qwen2vl_processor.apply_chat_template( + messages_qwen[:-1], tokenize=False, add_generation_prompt=True) + text_segs = re.split('', text) + text = '<|vision_start|><|image_pad|><|vision_end|>'.join(text_segs[:len(image_list) + 1]) + ''.join( + text_segs[len(image_list) + 1:]) + data_dict_qwen2vl = self.tokenizer.qwen2vl_processor( + text=[text], images=image_list, padding=True, return_tensors='pt') + results = {} + + results['images'] = images + results['image_sizes'] = image_sizes + results['pixel_values'] = data_dict_qwen2vl['pixel_values'] + results['image_grid_thw'] = data_dict_qwen2vl['image_grid_thw'] + return results + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + images = inputs.images or [] + input_ids = encoded['input_ids'] + labels = encoded['labels'] + if images: + results = self.process_images(inputs, images) + encoded['images'] = results['images'] 
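+ # 'images'/'image_sizes' come from Valley's own image processor (see preprocess_images above),
+ # while 'pixel_values'/'image_grid_thw' below are produced by the bundled qwen2vl_processor.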
+ encoded['image_sizes'] = results['image_sizes'] + encoded['pixel_values'] = results['pixel_values'] + encoded['image_grid_thw'] = results['image_grid_thw'] + encoded['input_ids'] = input_ids + encoded['labels'] = labels + return encoded + + def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: + res = super()._data_collator(batch, padding_to=padding_to) + if 'images' in batch[0]: + res['images'] = sum([b['images'] for b in batch if 'images' in b], start=[]) + res['image_sizes'] = torch.concat([b['image_sizes'] for b in batch if 'image_sizes' in b], dim=0) + for media_type in ['image', 'video']: + grid_thw = [b[f'{media_type}_grid_thw'] for b in batch if b.get(f'{media_type}_grid_thw') is not None] + if grid_thw: + res[f'{media_type}_grid_thw'] = torch.concat(grid_thw) + return res + + +register_template(ValleyTemplateMeta( + MLLMTemplateType.valley, + template_cls=ValleyTemplate, +)) diff --git a/swift/llm/template/template/yi.py b/swift/llm/template/template/yi.py new file mode 100644 index 0000000000000000000000000000000000000000..9b0424fe4a2c8cbe1bd0dd7341751048c2df4284 --- /dev/null +++ b/swift/llm/template/template/yi.py @@ -0,0 +1,63 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, List, Optional + +import torch + +from ..base import Template +from ..constant import LLMTemplateType, MLLMTemplateType +from ..register import TemplateMeta, register_template +from ..template_inputs import StdTemplateInputs +from .utils import DEFAULT_SYSTEM, ChatmlTemplateMeta + +register_template(ChatmlTemplateMeta( + LLMTemplateType.yi_coder, + default_system=DEFAULT_SYSTEM, +)) + +yi_vl_default_system = ( + 'This is a chat between an inquisitive human and an AI assistant. Assume the role of the AI assistant. ' + "Read all the images carefully, and respond to the human's questions with informative, " + 'helpful, detailed and polite answers. 
' + '这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。' + '仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。') + + +class YiVLTemplate(Template): + image_placeholder = [[-200], '\n'] + use_model = True + + def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: + encoded = super()._encode(inputs) + model = self.model + from llava.mm_utils import expand2square + if not hasattr(model, 'vision_tower'): + model = model.model + image_processor = model.vision_tower.image_processor + images = inputs.images or [] + for i, image in enumerate(images): + background_color = tuple(int(x * 255) for x in image_processor.image_mean) + image = expand2square(image, background_color) + images[i] = image + if images: + image_tensor = image_processor.preprocess(images, return_tensors='pt')['pixel_values'] + encoded['images'] = image_tensor.to(model.dtype) + return encoded + + def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: + res = super()._data_collator(batch, padding_to=padding_to) + images = [b['images'] for b in batch if 'images' in b] + if images: + res['images'] = torch.concat(images) + return res + + +register_template( + TemplateMeta( + MLLMTemplateType.yi_vl, + prefix=[], + prompt=[[8308], ' Human: {{QUERY}}\n', [8308], ' Assistant:'], + chat_sep=['\n'], + suffix=['\n', [8308]], + default_system=yi_vl_default_system, + template_cls=YiVLTemplate, + system_prefix=['{{SYSTEM}}\n\n'])) diff --git a/swift/llm/train/__init__.py b/swift/llm/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..24b51f54449be443e5897c42acdb380475d27757 --- /dev/null +++ b/swift/llm/train/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .pt import SwiftPt, pt_main +from .rlhf import SwiftRLHF, rlhf_main +from .sft import SwiftSft, sft_main +from .tuner import get_multimodal_target_regex diff --git a/swift/llm/train/__pycache__/__init__.cpython-310.pyc b/swift/llm/train/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e017b6c5b262e358b1cd6f8b9e40cd8396e71478 Binary files /dev/null and b/swift/llm/train/__pycache__/__init__.cpython-310.pyc differ diff --git a/swift/llm/train/__pycache__/callback.cpython-310.pyc b/swift/llm/train/__pycache__/callback.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da07c7738ea1b28e1cf7c6370a49ea41053e0593 Binary files /dev/null and b/swift/llm/train/__pycache__/callback.cpython-310.pyc differ diff --git a/swift/llm/train/__pycache__/kto.cpython-310.pyc b/swift/llm/train/__pycache__/kto.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..970644368049bc469cdfe0cfcb2f192471621355 Binary files /dev/null and b/swift/llm/train/__pycache__/kto.cpython-310.pyc differ diff --git a/swift/llm/train/__pycache__/pt.cpython-310.pyc b/swift/llm/train/__pycache__/pt.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d666970656e90be9eb0610cbf853e02a22f8aed Binary files /dev/null and b/swift/llm/train/__pycache__/pt.cpython-310.pyc differ diff --git a/swift/llm/train/__pycache__/rlhf.cpython-310.pyc b/swift/llm/train/__pycache__/rlhf.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..793dc7c9bc86ac30c96d46bcb6242c3fce31d0a6 Binary files /dev/null and b/swift/llm/train/__pycache__/rlhf.cpython-310.pyc differ diff --git a/swift/llm/train/__pycache__/sft.cpython-310.pyc 
b/swift/llm/train/__pycache__/sft.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d8c6266241379e1fd1415ad0ea39f4b54ceb090 Binary files /dev/null and b/swift/llm/train/__pycache__/sft.cpython-310.pyc differ diff --git a/swift/llm/train/__pycache__/tuner.cpython-310.pyc b/swift/llm/train/__pycache__/tuner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4f621f561cc3a30dbc495b9f5ccf91f738268b9 Binary files /dev/null and b/swift/llm/train/__pycache__/tuner.cpython-310.pyc differ diff --git a/swift/llm/train/callback.py b/swift/llm/train/callback.py new file mode 100644 index 0000000000000000000000000000000000000000..2c466519b932dc843047d01988c8d5bc78a8da25 --- /dev/null +++ b/swift/llm/train/callback.py @@ -0,0 +1,80 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import types + +import numpy as np +import torch +from transformers import TrainerCallback + +from swift.utils import get_logger + +logger = get_logger() + + +class TrainerAdapterCallback(TrainerCallback): + + def __init__(self, args): + self.global_step = 0 + self.args = args + + # offload original_modules to cpu, to save memory + def on_train_begin(self, _args, state, control, **kwargs): + model = kwargs['model'] + if self.args.train_type == 'adalora': + model.peft_config['default'].total_step = state.max_steps + + def zero_grad(_self, *args, **kwargs): + _self.update_and_allocate(self.global_step + 1) + _self._zero_grad(*args, **kwargs) + + model._zero_grad = model.zero_grad + model.zero_grad = types.MethodType(zero_grad, model) + + def on_step_end(self, _args, state, control, **kwargs): + if self.args.train_type == 'adalora': + self.global_step = state.global_step + + +class DynamicLayerActivationCallback(TrainerCallback): + + def __init__(self, n_layers: int, step_interval: int, model: torch.nn.Module): + super().__init__() + self.n_layers = n_layers + self.step_interval = step_interval + self.model = model + layers_name = None + layers = None + for name, module in model.named_modules(): + if isinstance(module, torch.nn.ModuleList): + layers_name = name + layers = module + break + assert layers_name is not None + self.layers_attribute = layers_name + self.total_layers = len(layers) + + # Freeze all layers upon initialization + self.freeze_all_layers() + self.active_layers_indices = [] + + def freeze_all_layers(self): + layers = self.model.get_submodule(self.layers_attribute) + for layer in layers: + for param in layer.parameters(): + param.requires_grad = False + + def on_step_begin(self, args, state, control, **kwargs): + # Check if it's time to switch active layers, including at step 0 + if state.global_step % self.step_interval == 0 or state.global_step == 1: + self.switch_active_layers() + + def switch_active_layers(self): + # First, disable gradients for all layers + self.freeze_all_layers() + + # Randomly select n_layers to activate + layers = self.model.get_submodule(self.layers_attribute) + self.active_layers_indices = np.random.choice(range(self.total_layers), self.n_layers, replace=False) + # Enable gradients only for the selected layers + for idx in self.active_layers_indices: + for param in layers[idx].parameters(): + param.requires_grad = True diff --git a/swift/llm/train/kto.py b/swift/llm/train/kto.py new file mode 100644 index 0000000000000000000000000000000000000000..5bd319a62656f09fd6f6c3cb0949475f9afd9b5f --- /dev/null +++ b/swift/llm/train/kto.py @@ -0,0 +1,75 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
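+# KTO needs "mismatched" completions to estimate its KL term: within each batch the assistant
+# responses are rotated by one position, e.g. responses [r0, r1, r2] are paired as KL responses
+# [r2, r0, r1] (see KTOPreprocessor and _get_kl_dataset below).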
+import warnings +from typing import Any, Dict, Optional + +from datasets import Dataset as HfDataset + +from swift.utils import get_dist_setting, get_logger +from ..dataset import RowPreprocessor + +logger = get_logger() + + +class KTOPreprocessor(RowPreprocessor): + + def batched_preprocess(self, batched_row: Dict[str, Any], **kwargs) -> Dict[str, Any]: + batched_row = dict(batched_row) + messages = batched_row['messages'] + batch_size = len(messages) + kl_messages = [messages[-1]] + messages[:-1] + + kl_response = [] + for i in range(batch_size): + kl_message = kl_messages[i][-1] + assert kl_message['role'] == 'assistant' + kl_response.append(kl_message['content']) + # The name rejected_response is just for convenience in processing. + batched_row['rejected_response'] = kl_response + + return batched_row + + +def _get_kl_dataset(dataset: Optional[HfDataset], + total_batch_size: int, + num_proc: int, + seed: Optional[int] = None) -> Optional[HfDataset]: + # Shift one position to the right in each batch. + if dataset is None: + return + dataset = dataset.shuffle(seed) + return KTOPreprocessor()(dataset, batch_size=total_batch_size, num_proc=num_proc) + + +def prepare_kto_dataset(args, train_dataset, val_dataset): + world_size = get_dist_setting()[2] + total_batch_size = (world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps) + if total_batch_size <= 1: + raise ValueError('Batch size is 1 (too small). KTO will not work properly because the KL term ' + 'will be equivalent to the implied reward.') + train_dataset = _get_kl_dataset(train_dataset, total_batch_size, args.dataset_num_proc, args.data_seed) + val_dataset = _get_kl_dataset(val_dataset, total_batch_size, args.dataset_num_proc, args.data_seed) + + label = train_dataset['label'] + num_desirable = max(sum(label), 1) + num_undesirable = max(len(label) - num_desirable, 1) # "label" is binary + + if num_desirable != num_undesirable: + # The lower and upper bounds come from Eq. (8) of https://huggingface.co/papers/2402.01306 + des_weight_lower_bound = round((num_undesirable * args.undesirable_weight / num_desirable) * 1, 2) + des_weight_upper_bound = round((num_undesirable * args.undesirable_weight / num_desirable) * 1.33, 2) + und_weight_lower_bound = round((num_desirable * args.desirable_weight / num_undesirable) / 1.33, 2) + und_weight_upper_bound = round((num_desirable * args.desirable_weight / num_undesirable) / 1, 2) + + des_weight_in_range = des_weight_lower_bound <= args.desirable_weight <= des_weight_upper_bound + und_weight_in_range = und_weight_lower_bound <= args.undesirable_weight <= und_weight_upper_bound + + if not (des_weight_in_range or und_weight_in_range): + logger.info(f'desirable_weight: {args.desirable_weight}, undesirable_weight: {args.undesirable_weight}') + warnings.warn( + f""" + You have different amounts of desirable/positive and undesirable/negative examples but the + weights on the desirable and undesirable losses don't seem to be in an ideal range. Based + on your data, we recommend EITHER desirable_weight in [{des_weight_lower_bound}, '{des_weight_upper_bound}] + or undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). 
+ See the documentation on how to optimally set these weights.""", UserWarning) + return train_dataset, val_dataset diff --git a/swift/llm/train/pt.py b/swift/llm/train/pt.py new file mode 100644 index 0000000000000000000000000000000000000000..4ed90a83a3ff031b68d7441ba3cd6915afd1e757 --- /dev/null +++ b/swift/llm/train/pt.py @@ -0,0 +1,19 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import List, Union + +from ..argument import TrainArguments +from .sft import SwiftSft + + +class SwiftPt(SwiftSft): + args_class = TrainArguments + args: args_class + + def _prepare_template(self) -> None: + self.args.use_chat_template = False + super()._prepare_template() + self.template.loss_scale = 'all' + + +def pt_main(args: Union[List[str], TrainArguments, None] = None): + return SwiftPt(args).main() diff --git a/swift/llm/train/rlhf.py b/swift/llm/train/rlhf.py new file mode 100644 index 0000000000000000000000000000000000000000..ecc7222599a7d3862658e201526d3943f78414ad --- /dev/null +++ b/swift/llm/train/rlhf.py @@ -0,0 +1,154 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from typing import List, Union + +from swift.llm import safe_snapshot_download +from swift.utils import get_logger, get_model_parameter_info +from ..argument import BaseArguments, RLHFArguments +from ..model import HfConfigFactory +from .kto import prepare_kto_dataset +from .sft import SwiftSft + +logger = get_logger() + + +class SwiftRLHF(SwiftSft): + args_class = RLHFArguments + args: args_class + + def _prepare_model_tokenizer(self): + if self.args.sequence_parallel_size > 1: + # Duplicate calling is allowd to promise this function will + # be called before model initializing. + from swift.trainers.sequence_parallel import sequence_parallel + sequence_parallel.init_sequence_parallel(self.args.sequence_parallel_size) + # prepare ref/reward/value model + from swift.llm.infer.utils import prepare_adapter + args = self.args + + def prepare_single_model(key, origin_key=None): + origin_key = origin_key or key + model_id_or_path = getattr(args, f'{key}_model') + if model_id_or_path is None: + return None + + model_type = getattr(args, f'{key}_model_type') + model_revision = getattr(args, f'{key}_model_revision') + model_dir = safe_snapshot_download( + model_id_or_path=model_id_or_path, + revision=model_revision, + download_model=False, + use_hf=args.use_hf, + hub_token=args.hub_token, + ) + task_type = None + num_labels = None + if os.path.exists(os.path.join(model_dir, 'args.json')): + model_args = BaseArguments.from_pretrained(model_dir) + if hasattr(model_args, 'task_type'): + task_type = model_args.task_type + else: + from transformers import AutoConfig + model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) + if hasattr(model_config, 'num_labels'): + num_labels = model_config.num_labels + if task_type == 'seq_cls': + num_labels = 1 + + model, processor = args.get_model_processor( + model=model_id_or_path, + model_type=model_type, + model_revision=model_revision, + task_type=task_type, + num_labels=num_labels) + + adapters = args.adapters if key == 'ref' else args.reward_adapters + model = prepare_adapter(args, model, adapters) + if origin_key in {'ref', 'reward'}: + if self.args.sequence_parallel_size > 1: + from swift.trainers.sequence_parallel import sequence_parallel + if hasattr(model, 'model_meta'): + is_multimodal = model.model_meta.is_multimodal + else: + is_multimodal = model.model.model_meta.is_multimodal + sequence_parallel.prepare_model(model, processor, 
split_in_forward=is_multimodal) + model.requires_grad_(False).eval() + else: + model = self.prepare_model(args, model, task_type=task_type) + logger.info(f'value_model: {model}') + model_parameter_info = get_model_parameter_info(model) + self.train_msg['value_model_parameter_info'] = model_parameter_info + logger.info(f'value_model_parameter_info: {model_parameter_info}') + + HfConfigFactory.set_model_config_attr(model, 'use_cache', False) + return model, processor + + # Handle ref and value models + for key in ['ref', 'value']: + setattr(self, f'{key}_model', None) + if key == 'value' and args.rlhf_type != 'ppo': + continue + + model_key = 'reward' if key == 'value' else key + result = prepare_single_model(model_key, key) + if result is not None: + model, _ = result + setattr(self, f'{key}_model', model) + + # Handle reward model(s) + self.reward_model = None + if hasattr(args, 'reward_model') and args.reward_model is not None: + reward_models = args.reward_model if isinstance(args.reward_model, list) else [args.reward_model] + self.reward_model = [] + if args.rlhf_type == 'grpo': + self.reward_template = [] + + for reward_model_path in reward_models: + args.reward_model = reward_model_path # Temporarily set for prepare_single_model + result = prepare_single_model('reward') + if result is not None: + model, processor = result + self.reward_model.append(model) + + if args.rlhf_type == 'grpo': + reward_template = self.args.get_template(processor, processor.model_meta.template) + if reward_template.use_model: + reward_template.model = model + self.reward_template.append(reward_template) + args.reward_model = reward_models # Restore original value + + super()._prepare_model_tokenizer() + + def _prepare_template(self) -> None: + args = self.args + super()._prepare_template() + model_mapping = {'kto': 'kto', 'ppo': 'pt', 'grpo': 'pt'} + self.template.set_mode(model_mapping.get(args.rlhf_type, 'rlhf')) + + if args.rlhf_type == 'ppo': + args.training_args.stop_token_id = self.template.template_meta.stop_token_id + + def _get_dataset(self): + args = self.args + train_dataset, val_dataset = super()._get_dataset() + if args.rlhf_type == 'kto': + train_dataset, val_dataset = prepare_kto_dataset(args, train_dataset, val_dataset) + return train_dataset, val_dataset + + def _get_trainer_kwargs(self): + trainer_kwargs = {} + for key in ['ref', 'reward', 'value']: + key = f'{key}_model' + model = getattr(self, key, None) + if model or self.args.rlhf_type == 'ppo': + trainer_kwargs[key] = model + if hasattr(self, 'reward_template'): + trainer_kwargs['reward_template'] = self.reward_template + if self.args.rlhf_type == 'grpo': + trainer_kwargs['reward_funcs'] = self.args.reward_funcs + trainer_kwargs['vllm_client'] = self.args.vllm_client + return trainer_kwargs + + +def rlhf_main(args: Union[List[str], RLHFArguments, None] = None): + return SwiftRLHF(args).main() diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py new file mode 100644 index 0000000000000000000000000000000000000000..6068aec234f07f96ac21fceb475e0c4702cff26b --- /dev/null +++ b/swift/llm/train/sft.py @@ -0,0 +1,287 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
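+# High-level flow (sketch): SwiftSft.__init__ prepares the model/processor, template and
+# callbacks; run() then loads and encodes the datasets, wraps the model with the chosen tuner
+# (TunerMixin.prepare_model), builds the trainer via TrainerFactory and starts training.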
+import os +from functools import partial +from typing import List, Union + +from datasets import Dataset as HfDataset + +from swift.plugin import extra_callbacks, get_loss_func, get_metric +from swift.trainers import TrainerFactory +from swift.utils import (append_to_jsonl, get_logger, get_model_parameter_info, is_master, plot_images, stat_array, + use_torchacc) +from ..argument import TrainArguments +from ..base import SwiftPipeline +from ..dataset import (EncodePreprocessor, GetLengthPreprocessor, IterablePackingDataset, LazyLLMDataset, + PackingDataset, load_dataset) +from ..infer import prepare_generation_config +from ..model import HfConfigFactory, get_model_arch +from ..utils import deep_getattr, dynamic_gradient_checkpointing +from .tuner import TunerMixin + +logger = get_logger() + + +class SwiftSft(SwiftPipeline, TunerMixin): + args_class = TrainArguments + args: args_class + + def __init__(self, args: Union[List[str], TrainArguments, None] = None) -> None: + super().__init__(args) + self.train_msg = {} + self._prepare_model_tokenizer() + self._prepare_template() + self._prepare_callbacks() + + def _prepare_gradient_checkpointing(self): + args = self.args + HfConfigFactory.set_model_config_attr(self.model, 'use_cache', False) + if args.gradient_checkpointing: + self.model.supports_gradient_checkpointing = True + dynamic_gradient_checkpointing(self.model) + self.model.enable_input_require_grads() + model_meta = self.model.model_meta + model_arch = get_model_arch(model_meta.model_arch) + if model_meta.is_multimodal and model_arch: + for vision_tower_name in model_arch.vision_tower: + vision_tower = deep_getattr(self.model, vision_tower_name) + if hasattr(vision_tower, 'enable_input_require_grads'): + try: + vision_tower.enable_input_require_grads() + except NotImplementedError: + pass + + def _prepare_generation_config(self): + args = self.args + self.model.origin_generation_config = self.model.generation_config + self.model.generation_config = prepare_generation_config(self.model.generation_config, + args.get_request_config(), self.tokenizer) + logger.info(f'model.generation_config: {self.model.generation_config}') + + def _prepare_model_tokenizer(self): + args = self.args + if args.sequence_parallel_size > 1: + from swift.trainers.sequence_parallel import sequence_parallel + sequence_parallel.init_sequence_parallel(args.sequence_parallel_size) + self.model, self.processor = args.get_model_processor() + + if hasattr(self.model, 'hf_device_map'): + logger.info(f'model.hf_device_map: {self.model.hf_device_map}') + + logger.info(f'model_info: {self.model.model_info}') + + self._prepare_generation_config() + self._prepare_gradient_checkpointing() + + def _prepare_template(self) -> None: + template = self.args.get_template(self.processor) + if self.args.task_type == 'causal_lm': + template.set_mode('train') + if template.use_model: + template.model = self.model + self.template = template + + def _get_dataset(self): + # The random shuffling of the training set occurs in the dataloader of the trainer. + args = self.args + dataset_kwargs = args.get_dataset_kwargs() + train_dataset, val_dataset = load_dataset( + args.dataset, split_dataset_ratio=args.split_dataset_ratio, shuffle=args.dataset_shuffle, **dataset_kwargs) + if len(args.val_dataset) > 0: + # Loading val dataset + _, val_dataset = load_dataset( + args.val_dataset, split_dataset_ratio=1.0, shuffle=args.val_dataset_shuffle, **dataset_kwargs) + assert args.split_dataset_ratio == 0. 
+ logger.info(f'train_dataset: {train_dataset}') + logger.info(f'val_dataset: {val_dataset}') + + return train_dataset, val_dataset + + def _get_loss_func(self): + args = self.args + loss_type = args.loss_type + if loss_type is None and args.loss_scale != 'default': + loss_type = 'loss_scale' + return get_loss_func(loss_type) + + def _get_data_collator(self): + args = self.args + template = self.template + padding_to = args.max_length if args.train_type == 'longlora' else None + return partial(template.data_collator, padding_to=padding_to) + + @staticmethod + def _save_val_dataset(output_dir: str, val_dataset): + if is_master() and isinstance(val_dataset, HfDataset): + os.makedirs(output_dir, exist_ok=True) + val_dataset_path = os.path.join(output_dir, 'val_dataset.jsonl') + append_to_jsonl(val_dataset_path, val_dataset.to_list()) + logger.info(f'The split dataset from the training set will be saved at: {val_dataset_path}.') + + def run(self): + args = self.args + + train_dataset, val_dataset = self._get_dataset() + train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset) + + if args.task_type == 'seq_cls': + args.problem_type = args.problem_type or getattr(self.model.config, 'problem_type', None) + logger.info(f'args.problem_type: {args.problem_type}') + args.save_args() + + data_collator = self._get_data_collator() + # Some tuners require train_dataset and data_collator for preparation: LoRA-GA + self.model = self.prepare_model(self.args, self.model, template=self.template, train_dataset=train_dataset) + logger.info(f'model: {self.model}') + model_parameter_info = get_model_parameter_info(self.model) + self.train_msg['model_parameter_info'] = model_parameter_info + logger.info(f'model_parameter_info: {model_parameter_info}') + + trainer_cls = TrainerFactory.get_trainer_cls(args) + trainer = trainer_cls( + model=self.model, + args=self.args.training_args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=val_dataset, + callbacks=self.callbacks, + template=self.template, + **self._get_trainer_kwargs(), + ) + return self.train(trainer) + + def _get_trainer_kwargs(self): + args = self.args + if args.metric is not None: + compute_metrics, preprocess_logits_for_metrics = get_metric(args.metric) + elif args.predict_with_generate: + compute_metrics, preprocess_logits_for_metrics = get_metric('nlg') + else: + compute_metrics, preprocess_logits_for_metrics = get_metric('acc') + compute_metrics = partial( + compute_metrics, acc_strategy=args.acc_strategy, is_encoder_decoder=self.template.is_encoder_decoder) + return { + 'compute_metrics': compute_metrics, + 'preprocess_logits_for_metrics': preprocess_logits_for_metrics, + 'compute_loss_func': self._get_loss_func() + } + + def _save_trainer_state(self, trainer): + training_args = trainer.args + state = trainer.state + if hasattr(state, 'last_model_checkpoint'): + if self.args.create_checkpoint_symlink: + last_checkpoint = os.path.join(self.args.output_dir, 'last') + best_checkpoint = os.path.join(self.args.output_dir, 'best') + os.symlink(state.last_model_checkpoint, last_checkpoint) + os.symlink(state.best_model_checkpoint, best_checkpoint) + state.last_model_checkpoint = last_checkpoint + state.best_model_checkpoint = best_checkpoint + else: + state.last_model_checkpoint = None + logger.warning('No training was carried out, which may be due to the dataset being too small ' + 'or incorrect usage of resume_from_checkpoint.') + logger.info(f'last_model_checkpoint: {state.last_model_checkpoint}') + 
logger.info(f'best_model_checkpoint: {state.best_model_checkpoint}') + + # Visualization + if is_master() and not use_torchacc(): + if 'tensorboard' in training_args.report_to: + images_dir = os.path.join(training_args.output_dir, 'images') + logger.info(f'images_dir: {images_dir}') + plot_images(images_dir, training_args.logging_dir, ['train/loss'], 0.9) + if training_args.push_to_hub: + trainer.push_to_hub() + + self.train_msg.update({ + 'last_model_checkpoint': state.last_model_checkpoint, + 'best_model_checkpoint': state.best_model_checkpoint, + 'best_metric': state.best_metric, + 'global_step': state.global_step, + 'log_history': state.log_history, + 'memory': trainer.max_memory, + }) + if is_master(): + jsonl_path = os.path.join(training_args.output_dir, 'logging.jsonl') + append_to_jsonl(jsonl_path, self.train_msg) + return self.train_msg + + def train(self, trainer): + logging_path = os.path.join(trainer.args.output_dir, 'logging.jsonl') + logger.info(f'The logging file will be saved in: {logging_path}') + try: + trainer.train(trainer.args.resume_from_checkpoint) + finally: + res = self._save_trainer_state(trainer) + return res + + def _prepare_callbacks(self): + from .callback import DynamicLayerActivationCallback, TrainerAdapterCallback + args = self.args + callbacks = [] + if args.lisa_activated_layers > 0: + assert args.train_type == 'full', 'LISA only supports full parameter training.' + lisa_callback = DynamicLayerActivationCallback( + n_layers=args.lisa_activated_layers, # Number of layers to activate + step_interval=args.lisa_step_interval, # Step interval to update active layers + model=self.model) + lisa_callback.switch_active_layers() # Make trainable parameters printing a correct value + callbacks.append(lisa_callback) + + if args.is_adapter and args.train_type == 'adalora': + callbacks.append(TrainerAdapterCallback(args)) + callbacks += extra_callbacks + self.callbacks = callbacks + + def _stat_dataset(self, dataset: HfDataset): + args = self.args + if isinstance(dataset, HfDataset): + dataset = GetLengthPreprocessor()(dataset, num_proc=args.dataset_num_proc) + length = dataset['length'] + else: + length = [] + for row in dataset: + length.append(max([len(row[k]) for k in row.keys() if k.endswith('input_ids')])) + _, stat_str = stat_array(length) + logger.info(f'Dataset Token Length: {stat_str}') + return stat_str + + def _encode_dataset(self, train_dataset, val_dataset): + template = self.template + args = self.args + output_dir = getattr(args, 'output_dir', None) or getattr(args, 'save') + self._save_val_dataset(output_dir, val_dataset) + is_grpo = hasattr(args, 'rlhf_type') and args.rlhf_type == 'grpo' + predict_with_generate = getattr(args, 'predict_with_generate', False) + if not is_grpo: + if args.packing: + packing_dataset_cls = IterablePackingDataset if args.streaming else PackingDataset + train_dataset = packing_dataset_cls( + self.template, train_dataset, num_proc=args.dataset_num_proc, strict=args.strict) + if val_dataset is not None: + val_dataset = packing_dataset_cls( + self.template, val_dataset, num_proc=args.dataset_num_proc, strict=args.strict) + elif args.lazy_tokenize: + train_dataset = LazyLLMDataset( + train_dataset, template.encode, strict=args.strict, random_state=args.data_seed) + if val_dataset is not None and not predict_with_generate: + val_dataset = LazyLLMDataset( + val_dataset, template.encode, strict=args.strict, random_state=args.data_seed) + else: + preprocessor = EncodePreprocessor(template=template) + train_dataset = 
preprocessor(train_dataset, num_proc=args.dataset_num_proc, strict=args.strict) + if val_dataset is not None and not predict_with_generate: + val_dataset = preprocessor(val_dataset, num_proc=args.dataset_num_proc, strict=args.strict) + + if is_master(): + inputs = train_dataset[0] if hasattr(train_dataset, '__len__') else next(iter(train_dataset)) + template.print_inputs(inputs, tokenizer_kwargs=inputs.pop('tokenizer_kwargs', None) or {}) + if isinstance(train_dataset, (HfDataset, PackingDataset)): + self.train_msg['train_dataset'] = self._stat_dataset(train_dataset) + if val_dataset is not None and not predict_with_generate: + self.train_msg['val_dataset'] = self._stat_dataset(val_dataset) + + return train_dataset, val_dataset + + +def sft_main(args: Union[List[str], TrainArguments, None] = None): + return SwiftSft(args).main() diff --git a/swift/llm/train/tuner.py b/swift/llm/train/tuner.py new file mode 100644 index 0000000000000000000000000000000000000000..531e98a2cd6a5ce76764b616a35b8ec62f7c9c78 --- /dev/null +++ b/swift/llm/train/tuner.py @@ -0,0 +1,424 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import inspect +import os +from typing import List, Union + +import torch +import torch.nn as nn +import transformers +from packaging import version +from transformers import TrainingArguments + +from swift.llm import TrainArguments, deep_getattr, get_model_arch +from swift.plugin import Tuner, extra_tuners +from swift.tuners import Swift +from swift.utils import (activate_parameters, find_all_linears, find_embedding, find_norm, freeze_parameters, + get_logger, use_torchacc) + +logger = get_logger() + + +def apply_liger(model_type: str): + from liger_kernel.transformers import (apply_liger_kernel_to_llama, apply_liger_kernel_to_mistral, + apply_liger_kernel_to_mixtral, apply_liger_kernel_to_gemma, + apply_liger_kernel_to_qwen2, apply_liger_kernel_to_qwen3, + apply_liger_kernel_to_qwen2_vl, apply_liger_kernel_to_qwen2_5_vl, + apply_liger_kernel_to_phi3, apply_liger_kernel_to_mllama) + from swift.llm import ModelType + if model_type in (ModelType.llama, ModelType.llama3, ModelType.llama3_1, ModelType.llama3_2): + apply_liger_kernel_to_llama() + elif model_type in (ModelType.mistral): + apply_liger_kernel_to_mistral() + elif model_type in (ModelType.mixtral): + apply_liger_kernel_to_mixtral() + elif model_type in (ModelType.gemma, ModelType.gemma2): + apply_liger_kernel_to_gemma() + elif model_type in (ModelType.qwen2, ModelType.qwen2_5): + apply_liger_kernel_to_qwen2() + elif model_type in (ModelType.qwen3): + apply_liger_kernel_to_qwen3() + elif model_type in (ModelType.phi3): + apply_liger_kernel_to_phi3() + elif model_type in (ModelType.llama3_2_vision): + apply_liger_kernel_to_mllama() + elif model_type in (ModelType.qwen2_vl): + apply_liger_kernel_to_qwen2_vl() + elif model_type in (ModelType.qwen2_5_vl): + apply_liger_kernel_to_qwen2_5_vl() + else: + raise ValueError(f'Unsupported liger model_type: {model_type}') + + +def get_multimodal_target_regex( + model, + *, + freeze_llm: bool = False, + freeze_vit: bool = True, + freeze_aligner: bool = True, + include_embedding: bool = False, +) -> str: + model_arch = get_model_arch(model.model_meta.model_arch) + modules = [] + if not freeze_llm: + modules += model_arch.language_model + if not freeze_vit: + modules += model_arch.vision_tower + if not freeze_aligner: + modules += model_arch.aligner + assert len(modules) > 0, f'modules: {modules}' + + extra_layers = [] + if include_embedding: + extra_layers.append(nn.Embedding) + res = [] 
+ for module in modules: + rejected_modules = [] + if not freeze_vit: + for aligner in model_arch.aligner: + if aligner.startswith(f'{module}.'): + rejected_modules.append(aligner) + + sub_module = deep_getattr(model, module) + target_modules = find_all_linears(sub_module, model_arch, extra_layers) + target_modules = [tm for tm in target_modules if tm] + target_pattern = rf'.*\.({"|".join(target_modules)})' if target_modules else '' + rejected_pattern = rf'(?!({"|".join(rejected_modules)}))' if rejected_modules else '' + res.append(rf'{rejected_pattern}{module}{target_pattern}') + + return rf'^({"|".join(res)})$' + + +def get_target_modules(args, model) -> Union[str, List[str]]: + """Replace 'all-linear' with the actual target module names""" + model_meta = model.model_meta + if isinstance(args.target_modules, str): + return args.target_modules + target_modules = args.target_modules.copy() + if 'all-linear' in target_modules: + if model_meta.is_multimodal: + return get_multimodal_target_regex( + model, + freeze_llm=args.freeze_llm, + freeze_vit=args.freeze_vit, + freeze_aligner=args.freeze_aligner, + include_embedding='all-embedding' in target_modules) + else: + target_modules.remove('all-linear') + target_modules += find_all_linears(model) + if 'all-embedding' in target_modules: + target_modules.remove('all-embedding') + target_modules += find_embedding(model) + return target_modules + + +def get_modules_to_save(args, model, task_type=None): + modules_to_save = args.modules_to_save.copy() + if 'all-embedding' in args.modules_to_save: + modules_to_save.remove('all-embedding') + modules_to_save += find_embedding(model) + if 'all-norm' in args.modules_to_save: + modules_to_save.remove('all-norm') + modules_to_save += find_norm(model) + if task_type and task_type.lower() == 'seq_cls': # reward_model + modules_to_save.append('v_head') + return modules_to_save + + +def get_vera_target_modules(model, config): + """This function is only used by the VeRA tuner""" + target_modules = config.target_modules + modules_dict = { + name: module.weight.shape + for name, module in model.named_modules() + if isinstance(module, torch.nn.Linear) and any([t in name for t in target_modules]) + } # only Linear for now + if len(set(modules_dict.values())) > 1: + v = [t for t in target_modules if 'v' in t] + if not v: + raise ValueError('Please manually pass in `vera_target_modules` and do not use `all-linear`, ' + 'because VeRA needs all target linear layers to be the same size.') + v = v[0] + shape = [shape for name, shape in modules_dict.items() if v in name][0] + names = [_name for _name, _shape in modules_dict.items() if _shape == shape] + config.target_modules = [t for t in target_modules if any([t in name for name in names])] + return config + + +def prepare_adapter(args: TrainArguments, model, *, template=None, train_dataset=None, task_type=None): + from swift.tuners import (AdaLoraConfig, AdapterConfig, BOFTConfig, LLaMAProConfig, LongLoRAModelType, LoraConfig, + LoRAConfig, ReftConfig, Swift, VeraConfig) + task_type = (task_type or args.task_type).upper() + target_modules = get_target_modules(args, model) + modules_to_save = get_modules_to_save(args, model, task_type) + lora_kwargs = { + 'r': args.lora_rank, + 'target_modules': target_modules, + 'lora_alpha': args.lora_alpha, + 'lora_dropout': args.lora_dropout, + 'bias': args.lora_bias, + 'modules_to_save': modules_to_save, + 'use_rslora': args.use_rslora, + 'use_dora': args.use_dora, + 'lorap_lr_ratio': args.lorap_lr_ratio, + 'init_lora_weights': args.init_weights, + } + if
args.train_type in ('lora', 'longlora'): + if args.use_swift_lora: + lora_config = LoRAConfig(lora_dtype=args.lora_dtype, **lora_kwargs) + model = Swift.prepare_model(model, lora_config) + logger.info(f'lora_config: {lora_config}') + elif args.tuner_backend == 'peft': + if task_type == 'EMBEDDING': + task_type = None + lora_config = LoraConfig(task_type=task_type, lora_dtype=args.lora_dtype, **lora_kwargs) + if args.init_weights == 'lora-ga': + try: + import lora_ga + except ImportError as e: + error_message = """ + Since 'LoRA-GA' is not implemented by PEFT, you will need to install it directly from GitHub. + Command: 'pip install git+https://github.com/lxline/LoRA-GA.git'. + """ + logger.info(error_message) + raise RuntimeError(error_message) from e + model = lora_ga.entrypoint.get_lora_ga_model( + model=model, + data_collator=template.data_collator, + dataset=train_dataset, + batch_size=args.lora_ga_batch_size, + num_iters=args.lora_ga_iters, + max_length=args.lora_ga_max_length, + direction=args.lora_ga_direction, + dtype=args.lora_dtype, + scale=args.lora_ga_scale, + stable_gamma=args.lora_ga_stable_gamma, + ) + else: + model = Swift.prepare_model(model, lora_config) + logger.info(f'lora_config: {lora_config}') + elif args.tuner_backend == 'unsloth': + if args.resume_from_checkpoint is None: + if args.model_meta.is_multimodal: + from unsloth import FastVisionModel as UnslothModel + else: + from unsloth import FastLanguageModel as UnslothModel + assert args.train_type == 'lora', 'Unsloth does not support LongLoRA' + lora_kwargs.pop('lorap_lr_ratio') + model = UnslothModel.get_peft_model( + model, + use_gradient_checkpointing='unsloth', + max_seq_length=args.max_length or 2048, # 2048 is the default value of unsloth + **lora_kwargs, + ) + logger.info(f'unsloth_config: {lora_kwargs}') + if args.train_type == 'longlora': + assert LongLoRAModelType.LLAMA in args.model_type + assert version.parse(transformers.__version__) >= version.parse('4.39.3') + from swift.tuners.longlora.llama import replace_llama_attn + replace_llama_attn(model) + model.config.group_size_ratio = 0.25 + elif args.train_type == 'adalora': + lora_kwargs.pop('lorap_lr_ratio', None) + lora_kwargs['rank_pattern'] = None + from swift.plugin.optimizer import calculate_max_steps + adalora_config = AdaLoraConfig( + task_type=task_type, + **lora_kwargs, + target_r=args.adalora_target_r, + init_r=args.adalora_init_r, + tinit=args.adalora_tinit, + tfinal=args.adalora_tfinal, + deltaT=args.adalora_deltaT, + beta1=args.adalora_beta1, + beta2=args.adalora_beta2, + orth_reg_weight=args.adalora_orth_reg_weight, + total_step=calculate_max_steps(args.training_args, train_dataset), + ) + model = Swift.prepare_model(model, adalora_config) + logger.info(f'adalora_config: {adalora_config}') + elif args.train_type == 'llamapro': + llamapro_config = LLaMAProConfig( + model_type=model.model_meta.model_arch, + num_new_blocks=args.llamapro_num_new_blocks, + num_groups=args.llamapro_num_groups) + model = Swift.prepare_model(model, llamapro_config) + logger.info(f'llamapro_config: {llamapro_config}') + elif args.train_type == 'adapter': + model_arch = get_model_arch(model.model_meta.model_arch) + mlp_key = model_arch.mlp + mlp_key = mlp_key.split('.{}.')[1] + adapter_config = AdapterConfig( + dim=model.config.hidden_size, + target_modules=[mlp_key], + hidden_pos=0, + adapter_length=args.adapter_length, + act_layer=args.adapter_act) + model = Swift.prepare_model(model, adapter_config) + logger.info(f'adapter_config: {adapter_config}') + elif 
args.train_type == 'vera': + vera_config = VeraConfig( + r=args.vera_rank, + target_modules=target_modules, + projection_prng_key=args.vera_projection_prng_key, + vera_dropout=args.vera_dropout, + d_initial=args.vera_d_initial, + modules_to_save=args.modules_to_save, + ) + vera_config = get_vera_target_modules(model, vera_config) + model = Swift.prepare_model(model, vera_config) + logger.info(f'vera_config: {vera_config}') + elif args.train_type == 'boft': + boft_config = BOFTConfig( + boft_block_size=args.boft_block_size, + boft_block_num=args.boft_block_num, + boft_n_butterfly_factor=args.boft_n_butterfly_factor, + target_modules=target_modules, + boft_dropout=args.boft_dropout, + modules_to_save=args.modules_to_save, + ) + model = Swift.prepare_model(model, boft_config) + logger.info(f'boft_config: {boft_config}') + elif args.train_type == 'fourierft': + from peft import FourierFTConfig + fourier_config = FourierFTConfig( + target_modules=target_modules, + modules_to_save=args.modules_to_save, + n_frequency=args.fourier_n_frequency, + scaling=args.fourier_scaling, + ) + model = Swift.prepare_model(model, fourier_config) + logger.info(f'fourier_config: {fourier_config}') + elif args.train_type == 'reft': + reft_config = ReftConfig( + model_type=model.model_meta.model_arch, + layer_key=args.reft_layer_key, + r=args.reft_rank, + layers=args.reft_layers, + intervention_type=args.reft_intervention_type, + args=args.reft_args, + ) + logger.info(f'reft config: {reft_config}') + model = Swift.prepare_model(model, {'reft': reft_config}) + elif args.train_type == 'bone': + # Version loosing + from peft import BoneConfig + bone_config = BoneConfig( + target_modules=target_modules, + r=args.reft_rank, + init_weights=args.init_weights, + ) + logger.info(f'bone config: {bone_config}') + model = Swift.prepare_model(model, bone_config) + return model + + +def torchacc_resume_from_checkpoint(args, model): + import safetensors + weights_file = os.path.join(args.resume_from_checkpoint, 'pytorch_model.bin') + safe_weights_file = os.path.join(args.resume_from_checkpoint, 'model.safetensors') + if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file): + if args.save_safetensors and os.path.isfile(safe_weights_file): + state_dict = safetensors.torch.load_file(safe_weights_file, device='cpu') + else: + state_dict = torch.load(weights_file, map_location='cpu') + model.load_state_dict(state_dict, False) + del state_dict + else: + from transformers.modeling_utils import load_sharded_checkpoint + # We load the sharded checkpoint + load_result = load_sharded_checkpoint( + model, args.resume_from_checkpoint, strict=False, prefer_safe=args.save_safetensors) + if len(load_result.missing_keys) != 0: + if model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set( + model._keys_to_ignore_on_save): + model.tie_weights() + else: + logger.warning(f'There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.') + if len(load_result.unexpected_keys) != 0: + logger.warning(f'There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}.') + + +class TunerMixin: + + @classmethod + def prepare_model(cls, args, model, *, template=None, train_dataset=None, task_type=None): + if args.use_liger_kernel and 'use_liger_kernel' not in inspect.signature(TrainingArguments).parameters: + # Apply liger + apply_liger(args.model_type) + + if args.is_adapter: + if args.tuner_backend != 'unsloth' and args.train_type not in extra_tuners: + # Fix the name of 
the layer in xcomposer that contains Plora. + # Unsloth prepares and loads lora outside this function when + # resume_from_checkpoint, so do not disable grad here + model.requires_grad_(False) + if args.resume_from_checkpoint: + if args.train_type in extra_tuners: + tuner: Tuner = extra_tuners[args.train_type] + else: + tuner = Swift + kwargs = {} + if use_torchacc(): + kwargs = {'adapter_name': 'default'} + model = tuner.from_pretrained(model, args.resume_from_checkpoint, is_trainable=True, **kwargs) + else: + if args.train_type in extra_tuners: + tuner: Tuner = extra_tuners[args.train_type] + model = tuner.prepare_model(args, model) + else: + model = prepare_adapter( + args, model, template=template, train_dataset=train_dataset, task_type=task_type) + # fix bug: Attempting to unscale FP16 gradients. + # peft: https://github.com/huggingface/peft/issues/1249 + for p in model.parameters(): + if p.requires_grad and p.dtype == torch.float16: + logger.info_once('Convert trainable parameters from fp16 to fp32.') + p.data = p.data.to(dtype=torch.float32) + elif args.train_type == 'full': + model.train() + model.requires_grad_(True) + + freeze_parameters(model, args.freeze_parameters_ratio, args.freeze_parameters, args.freeze_parameters_regex) + if len(args.trainable_parameters) > 0 or args.trainable_parameters_regex is not None: + activate_parameters(model, args.trainable_parameters, args.trainable_parameters_regex) + if use_torchacc() and args.resume_from_checkpoint: + torchacc_resume_from_checkpoint(args, model) + else: + raise ValueError(f'args.train_type: {args.train_type}') + + if args.resume_only_model: + args.training_args.resume_from_checkpoint = None + if args.use_galore: + from swift.trainers.optimizers.galore import GaLoreConfig + if args.galore_target_modules is None: + args.galore_target_modules = find_all_linears(model) + if args.galore_with_embedding: + args.galore_target_modules += find_embedding(model) + args.galore_config = GaLoreConfig( + target_modules=args.galore_target_modules, + rank=args.galore_rank, + update_proj_gap=args.galore_update_proj_gap, + galore_scale=args.galore_scale, + proj_type=args.galore_proj_type, + optim_per_parameter=args.galore_optim_per_parameter, + quantize=args.galore_quantization, + proj_quant=args.galore_proj_quant, + proj_bits=args.galore_proj_bits, + proj_group_size=args.galore_proj_group_size, + cos_threshold=args.galore_cos_threshold, + gamma_proj=args.galore_gamma_proj, + queue_size=args.galore_queue_size, + ) + args.training_args.galore_config = args.galore_config + + if args.sequence_parallel_size > 1: + from swift.trainers.sequence_parallel import sequence_parallel + if hasattr(model, 'model_meta'): + is_multimodal = model.model_meta.is_multimodal + else: + is_multimodal = model.model.model_meta.is_multimodal + # multimodal model must do split in basemodel's forward + # or the media embedding may occur error + sequence_parallel.prepare_model(model, template.tokenizer, split_in_forward=is_multimodal) + + return model diff --git a/swift/megatron/__init__.py b/swift/megatron/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9e7c6b4060af72fae421277ab7dd6932947b051c --- /dev/null +++ b/swift/megatron/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +try: + from .init import init_megatron_env + init_megatron_env() +except Exception: + # allows lint pass. 
+ raise + +from typing import TYPE_CHECKING + +from swift.utils.import_utils import _LazyModule + +if TYPE_CHECKING: + from .train import megatron_sft_main, megatron_pt_main + from .utils import convert_hf2mcore, convert_mcore2hf + from .argument import MegatronTrainArguments + from .model import MegatronModelType, MegatronModelMeta, get_megatron_model_meta, register_megatron_model +else: + _import_structure = { + 'train': ['megatron_sft_main', 'megatron_pt_main'], + 'utils': ['convert_hf2mcore', 'convert_mcore2hf'], + 'argument': ['MegatronTrainArguments'], + 'model': ['MegatronModelType', 'MegatronModelMeta', 'get_megatron_model_meta', 'register_megatron_model'] + } + + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/swift/megatron/argument/__init__.py b/swift/megatron/argument/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..032d3c471b46f0406c7512af2414e27063f5ba71 --- /dev/null +++ b/swift/megatron/argument/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .megatron_args import MegatronArguments +from .train_args import MegatronTrainArguments diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py new file mode 100644 index 0000000000000000000000000000000000000000..90309ff114a211f1f7681c4fa53407b85e89cd69 --- /dev/null +++ b/swift/megatron/argument/megatron_args.py @@ -0,0 +1,253 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import sys +from dataclasses import asdict, dataclass, field +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +import torch +from transformers.utils.versions import require_version + +from swift.llm.argument.base_args import to_abspath + + +@dataclass +class ExtraMegatronArguments: + padded_vocab_size: Optional[int] = None + rope_scaling: Optional[Union[dict, str]] = None + torch_dtype: Optional[torch.dtype] = None + + dataloader_persistent_workers: bool = True + dataloader_prefetch_factor: int = 10 + + model_type: Optional[str] = None + max_epochs: Optional[int] = None + + +@dataclass +class MegatronArguments(ExtraMegatronArguments): + # training + micro_batch_size: int = 1 + global_batch_size: int = 16 + recompute_granularity: Literal['selective', 'full'] = 'selective' + recompute_method: Literal['uniform', 'block'] = None + recompute_num_layers: Optional[int] = None + recompute_modules: List[str] = field(default_factory=lambda: ['core_attn']) + use_cpu_initialization: bool = False + deterministic_mode: bool = False + train_iters: Optional[int] = None + log_interval: int = 5 + tensorboard_dir: Optional[str] = None + no_masked_softmax_fusion: bool = False + no_bias_dropout_fusion: bool = False + no_bias_swiglu_fusion: bool = False + no_rope_fusion: bool = False + no_gradient_accumulation_fusion: bool = False + cross_entropy_loss_fusion: bool = False + calculate_per_token_loss: bool = True + use_flash_attn: bool = False + attention_backend: str = 'auto' # flash, fused, unfused, local, auto + optimizer: Literal['adam', 'sgd'] = 'adam' + dataloader_type: Literal['single', 'cyclic', 'external'] = 'cyclic' + manual_gc: bool = False + manual_gc_interval: int = 0 + + # learning rate + lr: float = 1e-5 + lr_decay_style: Literal['cosine', 'linear', 'constant'] = 'cosine' + # The default is None, which will be set to `train_iters`. 
+ lr_decay_iters: Optional[int] = None + lr_warmup_iters: int = 0 + min_lr: float = 0 + + # regularization + weight_decay: float = 0.1 + clip_grad: float = 1. + adam_beta1: float = 0.9 + adam_beta2: float = 0.95 + adam_eps: float = 1e-8 + sgd_momentum: float = 0.9 + + # checkpoint + save: Optional[str] = None + save_interval: int = 500 + no_save_optim: bool = False + no_save_rng: bool = False + load: Optional[str] = None + no_load_optim: bool = False + no_load_rng: bool = False + finetune: bool = False + ckpt_format: Literal['torch', 'torch_dist', 'zarr'] = 'torch_dist' + no_initialization: bool = True + auto_detect_ckpt_format: bool = True + exit_on_missing_checkpoint: bool = True + + # dist + distributed_backend: Literal['nccl', 'gloo'] = 'nccl' + use_distributed_optimizer: bool = True + tensor_model_parallel_size: int = 1 + pipeline_model_parallel_size: int = 1 + decoder_first_pipeline_num_layers: Optional[int] = None + decoder_last_pipeline_num_layers: Optional[int] = None + sequence_parallel: bool = False + context_parallel_size: int = 1 + tp_comm_overlap: bool = False + overlap_grad_reduce: bool = False + overlap_param_gather: bool = False + distributed_timeout_minutes: int = 60 + + # model + num_layers: Optional[int] = None + hidden_size: Optional[int] = None + ffn_hidden_size: Optional[int] = None + num_attention_heads: Optional[int] = None + group_query_attention: Optional[bool] = None + num_query_groups: Optional[int] = None + max_position_embeddings: Optional[int] = None + position_embedding_type: Literal['learned_absolute', 'rope', 'relative', 'none'] = 'rope' + rotary_base: Optional[int] = None + rotary_percent: float = 1. + normalization: Literal['LayerNorm', 'RMSNorm'] = 'RMSNorm' + norm_epsilon: Optional[float] = None + swiglu: Optional[bool] = None + untie_embeddings_and_output_weights: Optional[bool] = None + disable_bias_linear: Optional[bool] = None + add_qkv_bias: Optional[bool] = None + attention_dropout: Optional[float] = None + hidden_dropout: float = 0. 
+ kv_channels: Optional[int] = None + qk_layernorm: Optional[bool] = None + transformer_impl: Literal['local', 'transformer_engine'] = 'transformer_engine' + + # moe + num_experts: Optional[int] = None + moe_ffn_hidden_size: Optional[int] = None + moe_shared_expert_intermediate_size: Optional[int] = None + moe_router_topk: Optional[int] = None + moe_router_pre_softmax: Optional[bool] = None + moe_aux_loss_coeff: Optional[float] = None + + expert_model_parallel_size: int = 1 + moe_token_dispatcher_type: Literal['allgather', 'alltoall', 'alltoall_seq'] = 'alltoall' + moe_grouped_gemm: bool = False + moe_router_load_balancing_type: Literal['aux_loss', 'seq_aux_loss', 'sinkhorn', 'none'] = 'aux_loss' + moe_z_loss_coeff: Optional[float] = None + moe_expert_capacity_factor: Optional[float] = None + moe_shared_expert_overlap: bool = False + + # mixed precision + fp16: Optional[bool] = None + bf16: Optional[bool] = None + apply_query_key_layer_scaling: Optional[bool] = None + attention_softmax_in_fp32: bool = True + + # logging + log_params_norm: bool = False + log_throughput: bool = True + tensorboard_log_interval: int = 1 + tensorboard_queue_size: int = 50 + log_timers_to_tensorboard: bool = True + no_log_learning_rate_to_tensorboard: bool = False + log_validation_ppl_to_tensorboard: bool = True + log_memory_to_tensorboard: bool = True + logging_level: Optional[str] = None + wandb_project: Optional[str] = None + wandb_exp_name: Optional[str] = None + wandb_save_dir: Optional[str] = None + + # evaluate + eval_iters: int = 100 + eval_interval: Optional[int] = None + + # other + seed: int = 42 + seq_length: Optional[int] = None + num_workers: int = 4 + no_create_attention_mask_in_dataloader: bool = True + + def _set_default(self): + if self.num_query_groups is None: + self.num_query_groups = 1 + if self.norm_epsilon is None: + self.norm_epsilon = 1e-5 + if self.rotary_base is None: + self.rotary_base = 10000 + if self.attention_dropout is None: + self.attention_dropout = 0. + if self.untie_embeddings_and_output_weights is None: + self.untie_embeddings_and_output_weights = True + if self.swiglu is None: + self.swiglu = True + if self.add_qkv_bias is None: + self.add_qkv_bias = True + if self.disable_bias_linear is None: + self.disable_bias_linear = True + if self.moe_router_topk is None: + self.moe_router_topk = 2 + if self.moe_router_pre_softmax is None: + self.moe_router_pre_softmax = False + if self.moe_aux_loss_coeff is None: + self.moe_aux_loss_coeff = 0. 
+ if self.qk_layernorm is None: + self.qk_layernorm = False + + def _init_mixed_precision(self): + from swift.llm.argument.base_args.model_args import ModelArguments + ModelArguments._init_mixed_precision(self) + if self.apply_query_key_layer_scaling is None: + self.apply_query_key_layer_scaling = self.fp16 + if self.apply_query_key_layer_scaling: + os.environ['NVTE_APPLY_QK_LAYER_SCALING'] = '1' + + def _init_moe(self): + if self.moe_shared_expert_intermediate_size == 0: + self.moe_shared_expert_intermediate_size = None + if self.moe_ffn_hidden_size is None: + self.moe_ffn_hidden_size = self.ffn_hidden_size + else: + self.ffn_hidden_size = self.moe_ffn_hidden_size + + def __post_init__(self): + from swift.llm.argument.base_args.model_args import ModelArguments + if self.use_flash_attn or self.attention_backend == 'flash': + require_version('flash-attn') + os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' + self._set_default() + self.group_query_attention = self.num_query_groups > 1 + if self.rope_scaling is not None: + self.rope_scaling = ModelArguments.parse_to_dict(self.rope_scaling) + if self.eval_interval is None: + self.eval_interval = self.save_interval + if self.seq_length is None: + self.seq_length = self.max_position_embeddings + if self.tensorboard_dir is None and self.save is not None: + self.tensorboard_dir = f'{self.save}/runs' + self._init_moe() + self._init_mixed_precision() + + self.tensorboard_dir = to_abspath(self.tensorboard_dir) + + def _args_to_argv(self) -> Tuple[List[Any], Dict[str, Any]]: + new_args = [] + args_dict = asdict(self) + extra_args = {} + for k, value in args_dict.items(): + if k not in MegatronArguments.__annotations__: + extra_args[k] = value + continue + if value is None or value is False: + continue + new_args.append(f"--{k.replace('_', '-')}") + if isinstance(value, list): + new_args += [str(v) for v in value] + elif value is not True: + new_args.append(str(value)) + + return new_args, extra_args + + def parse_to_megatron(self): + new_args, extra_args = self._args_to_argv() + sys._old_argv = sys.argv + sys.argv = sys.argv[:1] + new_args + # parameter conflict + extra_args.pop('loss_scale', None) + return extra_args diff --git a/swift/megatron/argument/train_args.py b/swift/megatron/argument/train_args.py new file mode 100644 index 0000000000000000000000000000000000000000..c43b5e8f76c3e38bbbbc6083067bc4d44deaa281 --- /dev/null +++ b/swift/megatron/argument/train_args.py @@ -0,0 +1,53 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os +from dataclasses import dataclass + +import torch + +from swift.llm import BaseArguments +from swift.llm.argument.base_args import to_abspath +from swift.utils import add_version_to_work_dir, get_logger, init_process_group, is_master +from ..model import get_megatron_model_meta +from .megatron_args import MegatronArguments + +logger = get_logger() + + +@dataclass +class MegatronTrainArguments(MegatronArguments, BaseArguments): + add_version: bool = True + # dataset + lazy_tokenize: bool = False + packing: bool = False + + def init_model_args(self, config): + self.megatron_model_meta = get_megatron_model_meta(self.model_type) + kwargs = self.megatron_model_meta.convert_hf_config(config) + for k, v in kwargs.items(): + if getattr(self, k) is None: + setattr(self, k, v) + MegatronArguments.__post_init__(self) + self.extra_args = self.parse_to_megatron() + + def _init_save(self): + init_process_group() + if self.save is None: + self.save = f'megatron_output/{self.model_suffix}' + self.save = to_abspath(self.save) + if self.add_version: + self.save = add_version_to_work_dir(self.save) + logger.info(f'args.save: {self.save}') + if is_master(): + os.makedirs(self.save, exist_ok=True) + + def __post_init__(self): + self.sequence_parallel_size = self.context_parallel_size + self.load = to_abspath(self.load, check_path_exist=True) + BaseArguments.__post_init__(self) + self._init_save() + self.seq_length = self.seq_length or self.max_length + if self.streaming: + self.dataloader_type = 'external' + if self.num_workers > 1: + self.num_workers = 1 + logger.info('Using streaming dataset, setting args.num_workers to 1.') diff --git a/swift/megatron/init.py b/swift/megatron/init.py new file mode 100644 index 0000000000000000000000000000000000000000..72380c414a95364e32f199fa0556bc0f4283036e --- /dev/null +++ b/swift/megatron/init.py @@ -0,0 +1,81 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os +import sys +from contextlib import contextmanager + +from swift.llm import git_clone_github +from swift.utils import get_logger, is_megatron_available, safe_ddp_context, subprocess_run + +logger = get_logger() + + +def _patch_transformer_engine(): + try: + from transformer_engine.pytorch.attention import FusedRoPEFunc + except ImportError: + try: + import transformer_engine + transformer_engine.pytorch.attention.FusedRoPEFunc = ( + transformer_engine.pytorch.dot_product_attention.rope.FusedRoPEFunc) + except (ImportError, AttributeError): + pass + + +def new_cyclic_iter(iter): + from megatron.training import get_args + args = get_args() + max_epochs = args.max_epochs + i = 0 + while True: + if getattr(args, 'is_training', False): + if max_epochs and i >= max_epochs: + logger.info(f'Training of {i} epochs has been completed, the training has finished.') + break + logger.info(f'The training of Epoch {i} starts...') + for x in iter: + yield x + i += 1 + + +@contextmanager +def _training_context(): + from megatron.training import get_args + args = get_args() + args.is_training = True + try: + yield + finally: + args.is_training = False + + +def _patch_max_epochs(): + # support max_epochs + from megatron.training import training + train_step_origin = training.train_step + + def train_step(*args, **kwargs): + with _training_context(): + try: + return train_step_origin(*args, **kwargs) + except StopIteration: + return {}, True, True, True, 0, None, None + + training.train_step = train_step + + training.cyclic_iter = new_cyclic_iter + + +def _patch_megatron(): + _patch_transformer_engine() + _patch_max_epochs() + + +def init_megatron_env() -> None: + if 'MEGATRON_LM_PATH' not in os.environ: + os.environ['MEGATRON_LM_PATH'] = git_clone_github( + 'https://github.com/NVIDIA/Megatron-LM', branch='core_r0.12.0') + with safe_ddp_context(hash_id='megatron-lm'): + if not is_megatron_available(): + subprocess_run([sys.executable, '-m', 'pip', 'install', '-e', os.environ['MEGATRON_LM_PATH']]) + sys.path.insert(0, os.environ['MEGATRON_LM_PATH']) + _patch_megatron() diff --git a/swift/megatron/model/__init__.py b/swift/megatron/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d13a8d1b5e51e0c9192b792621340bfe06a6f6f --- /dev/null +++ b/swift/megatron/model/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from . import gpt +from .constant import MegatronModelType +from .register import MegatronModelMeta, get_megatron_model_meta, register_megatron_model diff --git a/swift/megatron/model/config.py b/swift/megatron/model/config.py new file mode 100644 index 0000000000000000000000000000000000000000..bd9c9656cf455ec0ba7a2035af3ded7c5e8a57e1 --- /dev/null +++ b/swift/megatron/model/config.py @@ -0,0 +1,57 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+from typing import Any, Dict + +from swift.utils import get_logger + +logger = get_logger() +config_mapping = { + 'num_layers': ['num_hidden_layers'], + 'hidden_size': ['hidden_size'], + 'ffn_hidden_size': ['intermediate_size'], + 'num_attention_heads': ['num_attention_heads'], + 'num_query_groups': ['num_key_value_heads'], + 'max_position_embeddings': ['max_position_embeddings'], + 'norm_epsilon': ['rms_norm_eps'], + 'rotary_base': ['rope_theta'], + 'padded_vocab_size': ['vocab_size'], + 'attention_dropout': ['attention_dropout'], + 'untie_embeddings_and_output_weights': ['tie_word_embeddings'], + 'swiglu': ['hidden_act'], + 'add_qkv_bias': ['attention_bias'], + 'disable_bias_linear': ['mlp_bias'], + 'kv_channels': ['head_dim'], + 'model_type': ['model_type'], + # moe + 'moe_ffn_hidden_size': ['moe_intermediate_size'], + 'moe_shared_expert_intermediate_size': ['shared_expert_intermediate_size'], + 'moe_router_topk': ['num_experts_per_tok'], + 'num_experts': ['num_experts'], + 'moe_router_pre_softmax': ['norm_topk_prob'], + 'moe_aux_loss_coeff': ['router_aux_loss_coef'], +} + + +def convert_hf_config(config) -> Dict[str, Any]: + megatron_config = {} + for k, hf_keys in config_mapping.items(): + for hf_k in hf_keys: + if hasattr(config, hf_k): + hf_v = getattr(config, hf_k) + if k == 'rotary_base': + megatron_config[k] = int(hf_v) + elif k in {'untie_embeddings_and_output_weights', 'disable_bias_linear', 'moe_router_pre_softmax'}: + megatron_config[k] = not hf_v + elif k == 'swiglu': + if hf_v == 'silu': + megatron_config[k] = True + else: + megatron_config[k] = hf_v + break + # compat llama3 + if getattr(config, 'rope_scaling', None) is not None: + if isinstance(config.rope_scaling, int): + megatron_config['rope_scaling'] = {'factor': config.rope_scaling, 'type': 'linear'} + elif isinstance(config.rope_scaling, dict): + megatron_config['rope_scaling'] = config.rope_scaling + logger.info(f'megatron_config: {megatron_config}') + return megatron_config diff --git a/swift/megatron/model/constant.py b/swift/megatron/model/constant.py new file mode 100644 index 0000000000000000000000000000000000000000..8eebb6aa76a43e70b5b6f83801a97b65e51dd9ce --- /dev/null +++ b/swift/megatron/model/constant.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +class MegatronModelType: + gpt = 'gpt' diff --git a/swift/megatron/model/gpt/__init__.py b/swift/megatron/model/gpt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8d9af9f71117d96920979a3ed4ab7579523a2263 --- /dev/null +++ b/swift/megatron/model/gpt/__init__.py @@ -0,0 +1,40 @@ +# Copyright (c) Alibaba, Inc. and its affiliates.
+from swift.llm import ModelType +from ..constant import MegatronModelType +from ..register import MegatronModelMeta, register_megatron_model +from .config import convert_gpt_hf_config +from .hf2mcore import convert_hf2mcore +from .mcore2hf import convert_mcore2hf +from .model import model_provider + +register_megatron_model( + MegatronModelMeta(MegatronModelType.gpt, [ + ModelType.qwen2, + ModelType.qwen2_5, + ModelType.qwq, + ModelType.qwq_preview, + ModelType.qwen2_5_math, + ModelType.llama, + ModelType.llama3, + ModelType.llama3_1, + ModelType.llama3_2, + ModelType.longwriter_llama3_1, + ModelType.codefuse_codellama, + ModelType.marco_o1, + ModelType.deepseek, + ModelType.deepseek_r1_distill, + ModelType.yi, + ModelType.yi_coder, + ModelType.sus, + ModelType.skywork_o1, + ModelType.openbuddy_llama, + ModelType.openbuddy_llama3, + ModelType.megrez, + ModelType.reflection, + ModelType.numina, + ModelType.ziya, + ModelType.mengzi3, + ModelType.qwen3, + ModelType.qwen2_moe, + ModelType.qwen3_moe, + ], model_provider, convert_gpt_hf_config, convert_mcore2hf, convert_hf2mcore)) diff --git a/swift/megatron/model/gpt/config.py b/swift/megatron/model/gpt/config.py new file mode 100644 index 0000000000000000000000000000000000000000..6658a952ab2e7255c50ca4c5451060cbecb288a2 --- /dev/null +++ b/swift/megatron/model/gpt/config.py @@ -0,0 +1,13 @@ +from typing import Any, Dict + +from ..config import convert_hf_config + + +def convert_gpt_hf_config(config) -> Dict[str, Any]: + res = convert_hf_config(config) + model_type = res.get('model_type') + if model_type in {'qwen3', 'qwen3_moe'}: + res['qk_layernorm'] = True + if model_type in {'qwen2_moe', 'qwen3_moe'}: + res.pop('ffn_hidden_size', None) + return res diff --git a/swift/megatron/model/gpt/hf2mcore.py b/swift/megatron/model/gpt/hf2mcore.py new file mode 100644 index 0000000000000000000000000000000000000000..46525df3c757c6e83aaf0a87a783a7acfde68135 --- /dev/null +++ b/swift/megatron/model/gpt/hf2mcore.py @@ -0,0 +1,74 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import torch +from megatron.training import get_args + + +def set_attn_state(args, mg_attn, hf_attn): + num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) + + # Copy weights + mg_attn.linear_qkv.weight.data.copy_( + torch.cat([ + hf_attn.q_proj.weight.reshape((num_query_groups, -1, args.hidden_size)), + hf_attn.k_proj.weight.reshape((num_query_groups, -1, args.hidden_size)), + hf_attn.v_proj.weight.reshape((num_query_groups, -1, args.hidden_size)), + ], + dim=1).reshape((-1, args.hidden_size))) + mg_attn.linear_proj.weight.data.copy_(hf_attn.o_proj.weight) + + # Copy bias + if args.add_qkv_bias: + mg_attn.linear_qkv.bias.data.copy_( + torch.cat([ + hf_attn.q_proj.bias.reshape((num_query_groups, -1)), + hf_attn.k_proj.bias.reshape((num_query_groups, -1)), + hf_attn.v_proj.bias.reshape((num_query_groups, -1)), + ], + dim=1).reshape(-1)) + if args.qk_layernorm: + mg_attn.q_layernorm.weight.data.copy_(hf_attn.q_norm.weight) + mg_attn.k_layernorm.weight.data.copy_(hf_attn.k_norm.weight) + + +def _set_mlp_state(mg_mlp, hf_mlp): + mg_mlp.linear_fc1.weight.data.copy_(torch.cat([hf_mlp.gate_proj.weight, hf_mlp.up_proj.weight], dim=0)) + mg_mlp.linear_fc2.weight.data.copy_(hf_mlp.down_proj.weight) + + +def set_mlp_state(args, mg_mlp, hf_mlp): + if args.num_experts: + mg_mlp.router.weight.data.copy_(hf_mlp.gate.weight) + if mg_mlp.shared_experts is not None: + mg_mlp.shared_experts.gate_weight.data.copy_(hf_mlp.shared_expert_gate.weight) + for expert_idx in range(args.num_experts): + _set_mlp_state(mg_mlp.experts.local_experts[expert_idx], hf_mlp.experts[expert_idx]) + + if mg_mlp.shared_experts is not None: + _set_mlp_state(mg_mlp.shared_experts, hf_mlp.shared_expert) + else: + _set_mlp_state(mg_mlp, hf_mlp) + + +def set_layer_state(args, mg_model, hf_model, layer_idx): + mg_layer = mg_model.decoder.layers[layer_idx] + hf_layer = hf_model.model.layers[layer_idx] + + set_attn_state(args, mg_layer.self_attention, hf_layer.self_attn) + set_mlp_state(args, mg_layer.mlp, hf_layer.mlp) + + post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight + if args.num_experts: + mg_layer.pre_mlp_layernorm.weight.data.copy_(post_attention_layernorm_weight) + else: + mg_layer.mlp.linear_fc1.layer_norm_weight.data.copy_(post_attention_layernorm_weight) + mg_layer.self_attention.linear_qkv.layer_norm_weight.data.copy_(hf_layer.input_layernorm.weight) + + +def convert_hf2mcore(hf_model, mg_model): + args = get_args() + mg_model.embedding.word_embeddings.weight.data.copy_(hf_model.model.embed_tokens.weight) + if args.untie_embeddings_and_output_weights: + mg_model.output_layer.weight.data.copy_(hf_model.lm_head.weight) + mg_model.decoder.final_layernorm.weight.data.copy_(hf_model.model.norm.weight) + for layer_idx in range(args.num_layers): + set_layer_state(args, mg_model, hf_model, layer_idx) diff --git a/swift/megatron/model/gpt/mcore2hf.py b/swift/megatron/model/gpt/mcore2hf.py new file mode 100644 index 0000000000000000000000000000000000000000..6f29abaf0e63482ef7a538f1171a74be3f5ea162 --- /dev/null +++ b/swift/megatron/model/gpt/mcore2hf.py @@ -0,0 +1,70 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+from megatron.training import get_args + + +def set_attn_state(args, mg_attn, hf_attn): + num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) + # Copy weights + mg_attn_weight = mg_attn.linear_qkv.weight.reshape((num_query_groups, -1, args.hidden_size)) + q_dim, kv_dim = hf_attn.q_proj.weight.shape[0] // num_query_groups, hf_attn.k_proj.weight.shape[ + 0] // num_query_groups + hf_attn.q_proj.weight.data.copy_(mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size)) + hf_attn.k_proj.weight.data.copy_(mg_attn_weight[:, q_dim:-kv_dim, :].reshape(-1, args.hidden_size)) + hf_attn.v_proj.weight.data.copy_(mg_attn_weight[:, -kv_dim:, :].reshape(-1, args.hidden_size)) + hf_attn.o_proj.weight.data.copy_(mg_attn.linear_proj.weight) + + # Copy bias + if args.add_qkv_bias: + mg_attn_bias = mg_attn.linear_qkv.bias.reshape((num_query_groups, -1)) + hf_attn.q_proj.bias.data.copy_(mg_attn_bias[:, :q_dim].reshape(-1)) + hf_attn.k_proj.bias.data.copy_(mg_attn_bias[:, q_dim:-kv_dim].reshape(-1)) + hf_attn.v_proj.bias.data.copy_(mg_attn_bias[:, -kv_dim:].reshape(-1)) + + if args.qk_layernorm: + hf_attn.q_norm.weight.data.copy_(mg_attn.q_layernorm.weight) + hf_attn.k_norm.weight.data.copy_(mg_attn.k_layernorm.weight) + + +def _set_mlp_state(mg_mlp, hf_mlp): + ffn_hidden_size = hf_mlp.gate_proj.weight.shape[0] + hf_mlp.gate_proj.weight.data.copy_(mg_mlp.linear_fc1.weight[:ffn_hidden_size]) + hf_mlp.up_proj.weight.data.copy_(mg_mlp.linear_fc1.weight[ffn_hidden_size:]) + hf_mlp.down_proj.weight.data.copy_(mg_mlp.linear_fc2.weight) + + +def set_mlp_state(args, mg_mlp, hf_mlp): + if args.num_experts: + hf_mlp.gate.weight.data.copy_(mg_mlp.router.weight) + if mg_mlp.shared_experts is not None: + hf_mlp.shared_expert_gate.weight.data.copy_(mg_mlp.shared_experts.gate_weight) + for expert_idx in range(args.num_experts): + _set_mlp_state(mg_mlp.experts.local_experts[expert_idx], hf_mlp.experts[expert_idx]) + + if mg_mlp.shared_experts is not None: + _set_mlp_state(mg_mlp.shared_experts, hf_mlp.shared_expert) + else: + _set_mlp_state(mg_mlp, hf_mlp) + + +def set_layer_state(args, mg_model, hf_model, layer_idx): + mg_layer = mg_model.decoder.layers[layer_idx] + hf_layer = hf_model.model.layers[layer_idx] + set_attn_state(args, mg_layer.self_attention, hf_layer.self_attn) + set_mlp_state(args, mg_layer.mlp, hf_layer.mlp) + + post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight + if args.num_experts: + post_attention_layernorm_weight.data.copy_(mg_layer.pre_mlp_layernorm.weight) + else: + post_attention_layernorm_weight.data.copy_(mg_layer.mlp.linear_fc1.layer_norm_weight) + hf_layer.input_layernorm.weight.data.copy_(mg_layer.self_attention.linear_qkv.layer_norm_weight) + + +def convert_mcore2hf(hf_model, mg_model): + args = get_args() + hf_model.model.embed_tokens.weight.data.copy_(mg_model.embedding.word_embeddings.weight) + if args.untie_embeddings_and_output_weights: + hf_model.lm_head.weight.data.copy_(mg_model.output_layer.weight) + hf_model.model.norm.weight.data.copy_(mg_model.decoder.final_layernorm.weight) + for layer_idx in range(args.num_layers): + set_layer_state(args, mg_model, hf_model, layer_idx) diff --git a/swift/megatron/model/gpt/model.py b/swift/megatron/model/gpt/model.py new file mode 100644 index 0000000000000000000000000000000000000000..9bc6bf4fbc32dead7f5cea13cb3eae754c832b3e --- /dev/null +++ b/swift/megatron/model/gpt/model.py @@ -0,0 +1,37 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+from megatron.core.models.gpt import GPTModel +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.training import get_args +from megatron.training.arguments import core_transformer_config_from_args + +from ..rope import update_rope_inv_freq + + +def model_provider(pre_process=True, post_process=True): + args = get_args() + config = core_transformer_config_from_args(args) + config.variable_seq_lengths = True + transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts, args.moe_grouped_gemm, + args.qk_layernorm, args.multi_latent_attention) + if args.num_experts and args.moe_shared_expert_intermediate_size: + # qwen2_moe/qwen3_moe + transformer_layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True} + model = GPTModel( + config=config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=args.padded_vocab_size, + max_sequence_length=args.max_position_embeddings, + pre_process=pre_process, + post_process=post_process, + fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, + parallel_output=True, + share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + rotary_base=args.rotary_base, + rope_scaling=args.use_rope_scaling, + rope_scaling_factor=args.rope_scaling_factor, + seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor) + if args.rope_scaling: + update_rope_inv_freq(model.rotary_pos_emb.inv_freq, args.rope_scaling) + return model diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py new file mode 100644 index 0000000000000000000000000000000000000000..11734757a30142e79f3e414d0c8b85d57f002860 --- /dev/null +++ b/swift/megatron/model/register.py @@ -0,0 +1,47 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional + +import torch.nn as nn +from transformers import PretrainedConfig + +from swift.llm import MODEL_MAPPING, ModelGroup + +MEGATRON_MODEL_MAPPING = {} + + +@dataclass +class MegatronModelMeta: + megatron_model_type: str + model_types: List[str] + + model_provider: Callable[[], nn.Module] + convert_hf_config: Callable[[PretrainedConfig], Dict[str, Any]] + convert_mcore2hf: Callable[[nn.Module, nn.Module], None] + convert_hf2mcore: Callable[[nn.Module, nn.Module], None] + + +def register_megatron_model(megatron_model_meta: MegatronModelMeta, *, exist_ok: bool = False): + megatron_model_type = megatron_model_meta.megatron_model_type + for model_type in megatron_model_meta.model_types: + model_meta = MODEL_MAPPING[model_type] + model_meta.support_megatron = True + if not exist_ok and megatron_model_type in MEGATRON_MODEL_MAPPING: + raise ValueError(f'The `{megatron_model_type}` has already been registered in the MODEL_MAPPING.') + + MEGATRON_MODEL_MAPPING[megatron_model_type] = megatron_model_meta + + +_MODEL_META_MAPPING = None + + +def get_megatron_model_meta(model_type: str) -> Optional[MegatronModelMeta]: + global _MODEL_META_MAPPING + if _MODEL_META_MAPPING is None: + _MODEL_META_MAPPING = {} + for k, megatron_model_meta in MEGATRON_MODEL_MAPPING.items(): + for _model_type in megatron_model_meta.model_types: + _MODEL_META_MAPPING[_model_type] = k + if model_type not in _MODEL_META_MAPPING: + return + return MEGATRON_MODEL_MAPPING[_MODEL_META_MAPPING[model_type]] diff --git a/swift/megatron/model/rope.py b/swift/megatron/model/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..c127b2c7711811e2ff1092eb967b4d1459fb099f --- /dev/null +++ b/swift/megatron/model/rope.py @@ -0,0 +1,40 @@ +import math +from typing import Any, Dict + +import torch + + +def _to_llama3_rope(inv_freq: torch.Tensor, rope_scaling: Dict[str, Any]): + # copy from transformers + factor = rope_scaling['factor'] # `8` in the original implementation + low_freq_factor = rope_scaling['low_freq_factor'] # `1` in the original implementation + high_freq_factor = rope_scaling['high_freq_factor'] # `4` in the original implementation + old_context_len = rope_scaling['original_max_position_embeddings'] # `8192` in the original implementation + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + wavelen = 2 * math.pi / inv_freq + # wavelen < high_freq_wavelen: do nothing + # wavelen > low_freq_wavelen: divide by factor + inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq) + # otherwise: interpolate between the two, using a smooth factor + smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) + return inv_freq_llama + + +def _to_linear_rope(inv_freq: torch.Tensor, rope_scaling: Dict[str, Any]): + factor = rope_scaling['factor'] + inv_freq /= factor + return inv_freq + + +ROPE_MAPPING = {'llama3': _to_llama3_rope, 'linear': _to_linear_rope} + + +def update_rope_inv_freq(inv_freq: torch.Tensor, rope_scaling: Dict[str, Any]) -> None: + new_inv_freq = ROPE_MAPPING[rope_scaling['rope_type']](inv_freq, rope_scaling) + 
inv_freq.data.copy_(new_inv_freq) diff --git a/swift/megatron/train/__init__.py b/swift/megatron/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8f6a98be92e5e625a4295b74dee1e80cf0200608 --- /dev/null +++ b/swift/megatron/train/__init__.py @@ -0,0 +1,2 @@ +from .pt import megatron_pt_main +from .sft import megatron_sft_main diff --git a/swift/megatron/train/patcher.py b/swift/megatron/train/patcher.py new file mode 100644 index 0000000000000000000000000000000000000000..76a9862421746a4f8e20f92473269c3f596ce81e --- /dev/null +++ b/swift/megatron/train/patcher.py @@ -0,0 +1,64 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from contextlib import contextmanager +from functools import wraps + +import torch +from megatron.training import get_args, global_vars, initialize, training + +from swift.utils import JsonlWriter, is_master + + +@contextmanager +def patch_training_log(): + jsonl_writer = None + origin_training_log = training.training_log + + @wraps(origin_training_log) + def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration, loss_scale, + report_memory_flag, skipped_iter, grad_norm, params_norm, num_zeros_in_grad, *_args, **kwargs): + nonlocal jsonl_writer + args = get_args() + if is_master() and iteration % args.log_interval == 0: + logging_path = os.path.join(args.save, 'logging.jsonl') + logs = {} + for k, v in loss_dict.items(): + if isinstance(v, torch.Tensor): + v = v.item() + logs[k] = round(v, 8) + for k in {'grad_norm', 'params_norm', 'learning_rate'}: + v = locals()[k] + if v is not None: + logs[k] = round(v, 8) + logs['consumed_samples'] = args.consumed_train_samples + logs['global_step/max_steps'] = f'{iteration}/{args.train_iters}' + if jsonl_writer is None: + jsonl_writer = JsonlWriter(logging_path, enable_async=True) + jsonl_writer.append(logs) + return origin_training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration, + loss_scale, report_memory_flag, skipped_iter, grad_norm, params_norm, + num_zeros_in_grad, *_args, **kwargs) + + training.training_log = training_log + try: + yield + finally: + training.training_log = origin_training_log + + +@contextmanager +def patch_megatron_data_collator(data_collator): + origin_build_pretraining_data_loader = training.build_pretraining_data_loader + + def build_pretraining_data_loader(*_args, **kwargs): + args = get_args() + res = origin_build_pretraining_data_loader(*_args, **kwargs) + if res is not None and args.dataloader_type != 'external': + res.collate_fn = data_collator + return res + + training.build_pretraining_data_loader = build_pretraining_data_loader + try: + yield + finally: + training.build_pretraining_data_loader = origin_build_pretraining_data_loader diff --git a/swift/megatron/train/pt.py b/swift/megatron/train/pt.py new file mode 100644 index 0000000000000000000000000000000000000000..16f4bcd5905615776b0ec04d915f2548213f4e77 --- /dev/null +++ b/swift/megatron/train/pt.py @@ -0,0 +1,19 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
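+# Pre-training entry point: reuses the SFT pipeline but disables the chat template and sets loss_scale to
+# 'all' so the loss is computed on every token.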
+from typing import List, Union + +from ..argument import MegatronTrainArguments +from .sft import MegatronSft + + +class MegatronPt(MegatronSft): + args_class = MegatronTrainArguments + args: args_class + + def _prepare_template(self) -> None: + self.args.use_chat_template = False + super()._prepare_template() + self.template.loss_scale = 'all' + + +def megatron_pt_main(args: Union[List[str], MegatronTrainArguments, None] = None): + return MegatronPt(args).main() diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py new file mode 100644 index 0000000000000000000000000000000000000000..4fa3e24f18e381f8f3e8d6b778e9138fbe048dfd --- /dev/null +++ b/swift/megatron/train/sft.py @@ -0,0 +1,65 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from typing import List, Union + +from megatron.core.enums import ModelType +from megatron.training import pretrain + +from swift.llm.train import SwiftSft +from swift.utils import get_logger, is_master, plot_images +from ..argument import MegatronTrainArguments +from ..utils import patch_megatron_tokenizer +from .patcher import patch_megatron_data_collator, patch_training_log +from .utils import build_streaming_dataloader, forward_step, get_swift_datasets_provider + +logger = get_logger() + + +class MegatronSft(SwiftSft): + args_class = MegatronTrainArguments + args: args_class + + def __init__(self, args: Union[List[str], MegatronTrainArguments, None] = None) -> None: + self.train_msg = {} + super(SwiftSft, self).__init__(args) + args = self.args + _, self.processor = args.get_model_processor(load_model=False) + patch_megatron_tokenizer(self.processor) + args.init_model_args(self.processor.model_info.config) + self._prepare_template() + self.template.use_megatron = True + args.save_args(args.save) + + def run(self): + args = self.args + + train_dataset, val_dataset = self._get_dataset() + train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset) + data_collator = self.template.data_collator + if args.streaming: + train_dataset = build_streaming_dataloader(args, train_dataset, data_collator) + if val_dataset is not None: + val_dataset = build_streaming_dataloader(args, val_dataset, data_collator) + datasets_provider = get_swift_datasets_provider(train_dataset, val_dataset) + datasets_provider.is_distributed = True + + logging_path = os.path.join(args.save, 'logging.jsonl') + logger.info(f'The logging file will be saved in: {logging_path}') + try: + with patch_training_log(), patch_megatron_data_collator(data_collator): + pretrain( + datasets_provider, + args.megatron_model_meta.model_provider, + ModelType.encoder_or_decoder, + forward_step, + args_defaults=args.extra_args) + finally: + # Visualization + if is_master(): + images_dir = os.path.join(args.save, 'images') + logger.info(f'images_dir: {images_dir}') + plot_images(images_dir, args.tensorboard_dir) + + +def megatron_sft_main(args: Union[List[str], MegatronTrainArguments, None] = None): + return MegatronSft(args).main() diff --git a/swift/megatron/train/utils.py b/swift/megatron/train/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..69caa161d16d091fa530c7715f06b6ca95f40d6f --- /dev/null +++ b/swift/megatron/train/utils.py @@ -0,0 +1,229 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
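+# Batch utilities for Megatron training: broadcast each micro-batch from the tensor-parallel source rank,
+# derive PackedSeqParams (qkv_format 'thd') from position_ids, and split packed sequences into 2 * cp_size
+# chunks so that context-parallel ranks receive balanced workloads under causal masking.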
+from functools import partial +from typing import Any, Dict, Optional + +import torch +from megatron.core import mpu +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.utils import StragglerDetector +from megatron.training import get_args, get_timers +from megatron.training.training import cyclic_iter + +from swift.llm import DataLoaderDispatcher + +stimer = StragglerDetector() + + +def get_swift_datasets_provider(train_dataset, val_dataset): + + def swift_datasets_provider(train_val_test_num_samples): + return train_dataset, val_dataset, None + + return swift_datasets_provider + + +class MegatronDataLoaderDispatcher(DataLoaderDispatcher): + + @property + def group(self): + return mpu.get_data_parallel_group() + + +def build_streaming_dataloader(args, dataset, collate_fn): + base_dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=args.num_workers, + pin_memory=True, + collate_fn=collate_fn, + batch_size=args.micro_batch_size, + prefetch_factor=args.dataloader_prefetch_factor, + persistent_workers=args.dataloader_persistent_workers, + ) + return iter(cyclic_iter(MegatronDataLoaderDispatcher(base_dataloader))) + + +def get_batch_on_this_tp_rank(data_iterator): + # copy from megatron-lm + + args = get_args() + + def _broadcast(item): + if item is not None: + torch.distributed.broadcast( + item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group()) + + if mpu.get_tensor_model_parallel_rank() == 0: + + try: + data = next(data_iterator) + except StopIteration: + seq_length = -1 + else: + tokens = data['input_ids'] + seq_length = tokens.shape[1] + batch = { + 'tokens': tokens.cuda(non_blocking=True), + 'labels': data['labels'].cuda(non_blocking=True), + 'attention_mask': + None if 'attention_mask' not in data else data['attention_mask'].cuda(non_blocking=True), + 'position_ids': data['position_ids'].cuda(non_blocking=True) + } + seq_length = torch.tensor(seq_length).cuda(non_blocking=True) + _broadcast(seq_length) + if seq_length.item() == -1: + return {} + if args.pipeline_model_parallel_size == 1: + _broadcast(batch['tokens']) + _broadcast(batch['labels']) + _broadcast(batch['attention_mask']) + _broadcast(batch['position_ids']) + + elif mpu.is_pipeline_first_stage(): + _broadcast(batch['tokens']) + _broadcast(batch['attention_mask']) + _broadcast(batch['position_ids']) + + elif mpu.is_pipeline_last_stage(): + _broadcast(batch['labels']) + _broadcast(batch['attention_mask']) + _broadcast(batch['position_ids']) + + else: + seq_length = torch.empty((), dtype=torch.int64, device=torch.cuda.current_device()) + _broadcast(seq_length) + if seq_length.item() == -1: + return {} + micro_batch_size = 1 # use qkv_format 'thd' + tokens = torch.empty((micro_batch_size, seq_length), dtype=torch.int64, device=torch.cuda.current_device()) + labels = torch.empty((micro_batch_size, seq_length), dtype=torch.int64, device=torch.cuda.current_device()) + if args.create_attention_mask_in_dataloader: + attention_mask = torch.empty((micro_batch_size, 1, seq_length, seq_length), + dtype=torch.bool, + device=torch.cuda.current_device()) + else: + attention_mask = None + position_ids = torch.empty((micro_batch_size, seq_length), + dtype=torch.int64, + device=torch.cuda.current_device()) + + if args.pipeline_model_parallel_size == 1: + _broadcast(tokens) + _broadcast(labels) + _broadcast(attention_mask) + _broadcast(position_ids) + + elif mpu.is_pipeline_first_stage(): + labels = None + + _broadcast(tokens) + _broadcast(attention_mask) + 
_broadcast(position_ids) + + elif mpu.is_pipeline_last_stage(): + tokens = None + + _broadcast(labels) + _broadcast(attention_mask) + _broadcast(position_ids) # compat packing & cp + + batch = {'tokens': tokens, 'labels': labels, 'attention_mask': attention_mask, 'position_ids': position_ids} + + return batch + + +def get_packed_seq_params(position_ids: torch.Tensor) -> Optional[PackedSeqParams]: + position_ids_f = position_ids.flatten() + indices_q = torch.arange(position_ids_f.shape[0], device=position_ids_f.device, dtype=torch.int32) + + cu_seqlens = torch.cat([ + indices_q[position_ids_f == 0], + torch.tensor(position_ids_f.shape, device=position_ids_f.device, dtype=torch.int32), + ]) + + max_length = position_ids_f.max() + 1 + return PackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + max_seqlen_q=max_length, + max_seqlen_kv=max_length, + qkv_format='thd') + + +def _split_tokens(tokens, cu_seqlens): + assert tokens.shape[0] == 1, f'tokens.shape: {tokens.shape}' + new_tokens = [] + cp_size = mpu.get_context_parallel_world_size() + cp_rank = mpu.get_context_parallel_rank() + for i in range(cu_seqlens.shape[0] - 1): + val = tokens[:, cu_seqlens[i]:cu_seqlens[i + 1]] + val = val.view( + tokens.shape[0], + 2 * cp_size, + val.shape[1] // (2 * cp_size), + ) + index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device='cpu', + pin_memory=True).cuda(non_blocking=True) + val = val.index_select(1, index) + new_tokens.append(val.view(tokens.shape[0], -1)) + return torch.cat(new_tokens, dim=1) + + +def get_batch_on_this_cp_rank(batch: Dict[str, Any]): + """Slice batch input along sequence dimension into multiple chunks, + which are parallelized across GPUs in a context parallel group. + """ + + # With causal masking, each token only attends to its prior tokens. Simply split + # sequence into CP chunks can result in severe load imbalance. That's to say, chunks + # at the end of sequence have bigger workload than others. To address this issue, + # we split sequence into 2*CP ranks. Assuming CP=2, we then get 4 chunks, chunk_0 + # and chunk_3 are assigned to GPU0, chunk_1 and chunk_2 are assigned to GPU1, so + # that we can get balanced workload among GPUs in a context parallel group. + cp_size = mpu.get_context_parallel_world_size() + if cp_size > 1: + packed_seq_params = batch['packed_seq_params'] + for key, val in batch.items(): + if key == 'packed_seq_params': + continue + if val is not None: + batch[key] = _split_tokens(val, packed_seq_params.cu_seqlens_q) + + return batch + + +def get_batch(data_iterator): + """Generate a batch.""" + + # TODO: this is pretty hacky, find a better way + if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()): + return None, None, None, None, None + + # get batches based on the TP rank you are on + batch = get_batch_on_this_tp_rank(data_iterator) + if not batch: + return batch + batch['packed_seq_params'] = get_packed_seq_params(batch['position_ids']) + # slice batch along sequence dimension for context parallelism + batch = get_batch_on_this_cp_rank(batch) + return batch.values() + + +def forward_step(data_iterator, model): + from pretrain_gpt import loss_func + + timers = get_timers() + + # Get the batch. 
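+    # get_batch returns an empty dict once the data iterator is exhausted (the TP source rank broadcasts
+    # seq_length == -1); otherwise it yields tokens, labels, attention_mask, position_ids and
+    # packed_seq_params, already sliced for context parallelism.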
+ timers('batch-generator', log_level=2).start() + global stimer + with stimer(bdata=True): + data = get_batch(data_iterator) + if not data: + raise StopIteration + tokens, labels, attention_mask, position_ids, packed_seq_params = data + timers('batch-generator').stop() + + with stimer: + output_tensor = model(tokens, position_ids, attention_mask, labels=labels, packed_seq_params=packed_seq_params) + loss_mask = None if labels is None else (labels != -100).float() + return output_tensor, partial(loss_func, loss_mask) diff --git a/swift/megatron/utils/__init__.py b/swift/megatron/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4d2b722a2cf06a94691e9546b94247bca0998367 --- /dev/null +++ b/swift/megatron/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .convert import convert_hf2mcore, convert_mcore2hf +from .patcher import patch_megatron_tokenizer diff --git a/swift/megatron/utils/convert.py b/swift/megatron/utils/convert.py new file mode 100644 index 0000000000000000000000000000000000000000..42d37b945e1372af1662c8ce80e8eeea98523815 --- /dev/null +++ b/swift/megatron/utils/convert.py @@ -0,0 +1,122 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math + +import torch +from megatron.training.checkpointing import load_checkpoint +from megatron.training.checkpointing import save_checkpoint as mg_save_checkpoint +from megatron.training.initialize import initialize_megatron +from megatron.training.utils import get_ltor_masks_and_position_ids + +from swift.llm import ExportArguments, get_model_tokenizer, get_template, save_checkpoint +from swift.utils import get_logger, get_n_params_grads +from ..argument import MegatronArguments +from ..model import get_megatron_model_meta +from .patcher import patch_megatron_tokenizer, patch_torch_dist_shard + +logger = get_logger() + + +def test_convert_precision(hf_model, mg_model, processor): + torch_dtype = hf_model.dtype + template = get_template(hf_model.model_meta.template, processor) + input_ids = template.encode({'messages': [{'role': 'user', 'content': 'who are you?'}]})['input_ids'] + input_ids = torch.tensor(input_ids)[None].to('cuda') + hf_model.to('cuda') + hf_model.to(torch.float32) + with torch.inference_mode(): + hf_logits = hf_model(input_ids).logits + hf_model.to(torch_dtype) + hf_model.to('cpu') + + attention_mask, _, position_ids = get_ltor_masks_and_position_ids(input_ids, -100, True, True, True) + mg_model.to('cuda') + mg_model.to(torch.float32) + with torch.inference_mode(): + mg_logits = mg_model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) + mg_model.to(torch_dtype) + mg_model.to('cpu') + + mean_diff = (mg_logits - hf_logits).abs().mean().item() + max_diff = (mg_logits - hf_logits).abs().max().item() + print(f'mean_diff: {mean_diff}, max_diff: {max_diff}') + hf_tokens = hf_logits.argmax(-1) + mg_tokens = mg_logits.argmax(-1) + print(f'hf_tokens: {hf_tokens[0].tolist()}\nmg_tokens: {mg_tokens[0].tolist()}') + assert mean_diff < 0.1 + assert (hf_tokens == mg_tokens).all() + + +convert_kwargs = { + 'use_cpu_initialization': True, + 'no_save_optim': True, + 'no_save_rng': True, + 'no_load_optim': True, + 'no_load_rng': True, + 'no_masked_softmax_fusion': True, + 'no_bias_dropout_fusion': True, + 'no_bias_swiglu_fusion': True, + 'no_rope_fusion': True +} + + +def convert_hf2mcore(args: ExportArguments) -> None: + kwargs = args.get_model_kwargs() + hf_model, processor = get_model_tokenizer(**kwargs) + if 
args.thread_count is None: + checkpoint_size = sum(get_n_params_grads(hf_model)[0]) * torch.finfo(args.torch_dtype).bits // 8e9 + args.thread_count = max(math.ceil(checkpoint_size / 10), 2) # 10GB + patch_torch_dist_shard(args.thread_count) + + megatron_model_meta = get_megatron_model_meta(args.model_type) + assert megatron_model_meta is not None, f'Model: {args.model} is not supported.' + kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config) + megatron_args = MegatronArguments(**kwargs, **convert_kwargs, save=args.output_dir, torch_dtype=args.torch_dtype) + patch_megatron_tokenizer(processor) + extra_args = megatron_args.parse_to_megatron() + initialize_megatron(args_defaults=extra_args) + + mg_model = megatron_model_meta.model_provider() + logger.info('Megatron model created successfully.') + megatron_model_meta.convert_hf2mcore(hf_model, mg_model) + if args.test_convert_precision: + test_convert_precision(hf_model, mg_model, processor) + logger.info('Successfully transferred HF model weights to MG model.') + mg_save_checkpoint(1, [mg_model], None, None, 0) + args.save_args() + logger.info(f'Successfully saved Megatron model weights in `{args.output_dir}`.') + + +def convert_mcore2hf(args: ExportArguments) -> None: + kwargs = args.get_model_kwargs() + hf_model, processor = get_model_tokenizer(**kwargs) + if args.thread_count is None: + checkpoint_size = sum(get_n_params_grads(hf_model)[0]) * torch.finfo(args.torch_dtype).bits // 8e9 + args.thread_count = max(math.ceil(checkpoint_size / 10), 2) # 10GB + patch_torch_dist_shard(args.thread_count) + + megatron_model_meta = get_megatron_model_meta(args.model_type) + assert megatron_model_meta is not None, f'Model: {args.model} is not supported.' + kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config) + megatron_args = MegatronArguments(**kwargs, **convert_kwargs, load=args.mcore_model, torch_dtype=args.torch_dtype) + patch_megatron_tokenizer(processor) + extra_args = megatron_args.parse_to_megatron() + initialize_megatron(args_defaults=extra_args) + + mg_model = megatron_model_meta.model_provider() + load_checkpoint([mg_model], None, None, strict=True) + logger.info('Megatron model created successfully.') + megatron_model_meta.convert_mcore2hf(hf_model, mg_model) + if args.test_convert_precision: + test_convert_precision(hf_model, mg_model, processor) + logger.info('Successfully transferred MG model weights to HF model.') + save_checkpoint( + hf_model, + processor, + args.output_dir, + safe_serialization=args.safe_serialization, + model_dirs=[args.mcore_model, args.model_dir], + max_shard_size=args.max_shard_size, + additional_saved_files=hf_model.model_meta.additional_saved_files) + args.save_args() + logger.info(f'Successfully saved HF model weights in `{args.output_dir}`.') diff --git a/swift/megatron/utils/patcher.py b/swift/megatron/utils/patcher.py new file mode 100644 index 0000000000000000000000000000000000000000..5a4aed76fcb7e0dd6aff7b31641d34b619f29a8a --- /dev/null +++ b/swift/megatron/utils/patcher.py @@ -0,0 +1,26 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
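+# Patches applied around Megatron initialization: build_tokenizer is overridden to return the HF tokenizer
+# (recording extra_vocab_size against the padded vocab), and TorchDistSaveShardedStrategy is forced to use
+# the requested thread_count when saving distributed checkpoints.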
+from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy +from megatron.training import get_args, global_vars, initialize, training + +from swift.utils import get_logger + +logger = get_logger() + + +def patch_megatron_tokenizer(tokenizer): + + def build_tokenizer(args): + args.extra_vocab_size = args.padded_vocab_size - tokenizer.vocab_size + return tokenizer + + global_vars.build_tokenizer = build_tokenizer + + +def patch_torch_dist_shard(thread_count): + __init__ = TorchDistSaveShardedStrategy.__init__ + + def __new_init__(*args, **kwargs): + kwargs['thread_count'] = thread_count + return __init__(*args, **kwargs) + + TorchDistSaveShardedStrategy.__init__ = __new_init__ diff --git a/swift/plugin/.ipynb_checkpoints/__init__-checkpoint.py b/swift/plugin/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..109a4294314c7869d1b7e2cd7f1003c0c23aa50a --- /dev/null +++ b/swift/plugin/.ipynb_checkpoints/__init__-checkpoint.py @@ -0,0 +1,42 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from swift.utils.import_utils import _LazyModule + +if TYPE_CHECKING: + from .callback import extra_callbacks + from .loss import LOSS_MAPPING, get_loss_func + from .loss_scale import loss_scale_map + from .metric import InferStats, MeanMetric, Metric, compute_acc, get_metric, compute_rouge_bleu + from .optimizer import optimizers_map + from .agent_template import agent_templates + from .tuner import Tuner, extra_tuners, PeftTuner + from .prm import prms, PRM + from .orm import orms, ORM + from .multi_turn import multi_turns + from .rm_plugin import rm_plugins + +else: + _import_structure = { + 'callback': ['extra_callbacks'], + 'loss': ['LOSS_MAPPING', 'get_loss_func'], + 'loss_scale': ['loss_scale_map'], + 'metric': ['InferStats', 'MeanMetric', 'Metric', 'compute_acc', 'get_metric', 'compute_rouge_bleu'], + 'optimizer': ['optimizers_map'], + 'agent_template': ['agent_templates'], + 'tuner': ['Tuner', 'extra_tuners', 'PeftTuner'], + 'prm': ['prms', 'PRM'], + 'orm': ['orms', 'ORM'], + 'multi_turn': ['multi_turns'], + 'rm_plugin': ['rm_plugins'] + } + + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/swift/plugin/.ipynb_checkpoints/orm-checkpoint.py b/swift/plugin/.ipynb_checkpoints/orm-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..d5f1980f9067eab862bae2e01d09129d0d4fa750 --- /dev/null +++ b/swift/plugin/.ipynb_checkpoints/orm-checkpoint.py @@ -0,0 +1,406 @@ +import os +import re +from typing import Dict, List, Union + +import json + +from swift.llm import InferRequest + + +class ORM: + + def __call__(self, **kwargs) -> List[float]: + raise NotImplementedError + + +class ReactORM(ORM): + + @staticmethod + def evaluate_action_reward(action_pred: list, action_ref: list, cand_list: list, ref_list: list): + f1 = [] + for i in range(len(action_pred)): + ref_action = action_ref[i] + pred_action = action_pred[i] + + ref_input = ref_list[i] + cand_input = cand_list[i] + + ref_is_json = False + try: + ref_input_json = json.loads(ref_input) + ref_is_json = True + except Exception: + ref_input_json = ref_input + + cand_is_json = False + try: + cand_input_json = json.loads(cand_input) + cand_is_json = True + except Exception: + cand_input_json = cand_input + + if ref_action != pred_action or (ref_is_json ^ cand_is_json): + f1.append(0) 
+ elif not ref_is_json and not cand_is_json: + rougel = ReactORM.evaluate_rougel([ref_input_json], [cand_input_json]) + if rougel is None or rougel < 10: + f1.append(0) + elif 10 <= rougel < 20: + f1.append(0.1) + else: + f1.append(1) + else: + if not isinstance(ref_input_json, dict) or not isinstance(cand_input_json, dict): + # This cannot be happen, but: + # line 62, in evaluate_action_reward + # for k, v in ref_input_json.items(): + # AttributeError: 'str' object has no attribute 'items' + # print(f'>>>>>>ref_input_json: {ref_input_json}, cand_input_json: {cand_input_json}') + f1.append(0) + continue + + half_match = 0 + full_match = 0 + if ref_input_json == {}: + if cand_input_json == {}: + f1.append(1) + else: + f1.append(0) + else: + for k, v in ref_input_json.items(): + if k in cand_input_json.keys(): + if cand_input_json[k] == v: + full_match += 1 + else: + half_match += 1 + + recall = (0.5 * half_match + full_match) / (len(ref_input_json) + 1e-30) + precision = (0.5 * half_match + full_match) / (len(cand_input_json) + 1e-30) + try: + f1.append((2 * recall * precision) / (recall + precision)) + except Exception: + f1.append(0.0) + + if f1[0] == 1.0: + return True + else: + return False + + @staticmethod + def parse_action(text): + if 'Action Input:' in text: + input_idx = text.rindex('Action Input:') + action_input = text[input_idx + len('Action Input:'):].strip() + else: + action_input = '{}' + + if 'Action:' in text: + action_idx = text.rindex('Action:') + action = text[action_idx + len('Action:'):].strip() + if 'Action Input:' in action: + input_idx = action.index('Action Input:') + action = action[:input_idx].strip() + else: + action = 'none' + return action, action_input + + @staticmethod + def parse_output(text): + action, action_input = ReactORM.parse_action(text) + return action, action_input + + def __call__(self, infer_requests: List[Union[InferRequest, Dict]], solution: List[str], **kwargs) -> List[float]: + rewards = [] + if not isinstance(infer_requests[0], str): + predictions = [request['messages'][-1]['content'] for request in infer_requests] + else: + predictions = infer_requests + for prediction, ground_truth in zip(predictions, solution): + if prediction.endswith('Observation:'): + prediction = prediction[:prediction.index('Observation:')].strip() + action_ref = [] + action_input_ref = [] + action_pred = [] + action_input_pred = [] + reference = ground_truth + prediction = prediction.replace('<|endoftext|>', '').replace('<|im_end|>', '').strip() + ref_action, ref_input = ReactORM.parse_output(reference) + pred_action, pred_input = ReactORM.parse_output(prediction) + action_ref.append(ref_action) + action_input_ref.append(ref_input) + if pred_action is None: + action_pred.append('none') + else: + action_pred.append(pred_action) + + if pred_input is None: + action_input_pred.append('{}') + else: + action_input_pred.append(pred_input) + + reward = ReactORM.evaluate_action_reward(action_pred, action_ref, action_input_pred, action_input_ref) + rewards.append(float(reward)) + return rewards + + @staticmethod + def evaluate_rougel(cand_list: list, ref_list: list): + if len(ref_list) == 0: + return None + try: + from rouge import Rouge + rouge = Rouge() + rouge_score = rouge.get_scores(hyps=cand_list, refs=ref_list, avg=True) + rougel = rouge_score['rouge-l']['f'] + return rougel + except Exception: + return None + + +class MathORM(ORM): + + def __init__(self): + from transformers.utils import strtobool + self.use_opencompass = 
strtobool(os.environ.get('USE_OPENCOMPASS_EVALUATOR', 'False')) + if self.use_opencompass: + from opencompass.datasets.math import MATHEvaluator + self.evaluator = MATHEvaluator() + + @staticmethod + def check_terminate(answers: Union[str, List[str]]) -> List[bool]: + if isinstance(answers, str): + answers = [answers] + results = [] + for answer in answers: + results.append('\\boxed' in answer) + return results + + @staticmethod + def extract_boxed_result(text): + pattern = r'\\boxed{([^}]*)}' + match = re.search(pattern, text) + if match: + return match.group(1).strip() + else: + return text + + @staticmethod + def clean_latex(latex_str): + latex_str = re.sub(r'\\\(|\\\)|\\\[|\\]', '', latex_str) + latex_str = latex_str.replace('}}', '}').replace('{', '').replace('}', '') + return latex_str.strip() + + @staticmethod + def parse_expression(latex_str): + from sympy import simplify + from sympy.parsing.latex import parse_latex + try: + expr = parse_latex(latex_str) + return simplify(expr) + except Exception: + return None + + @staticmethod + def compare_consecutive(first, second): + cleaned_list = [MathORM.clean_latex(latex) for latex in [first, second]] + parsed_exprs = [MathORM.parse_expression(latex) for latex in cleaned_list] + if hasattr(parsed_exprs[0], 'equals') and hasattr(parsed_exprs[1], 'equals'): + value = parsed_exprs[0].equals(parsed_exprs[1]) + else: + value = parsed_exprs[0] == parsed_exprs[1] + if value is None: + value = False + return value + + def __call__(self, infer_requests: List[Union[InferRequest, Dict]], ground_truths: List[str], + **kwargs) -> List[float]: + rewards = [] + predictions = [request.messages[-1]['content'] for request in infer_requests] + for prediction, ground_truth in zip(predictions, ground_truths): + if '# Answer' in prediction: + prediction = prediction.split('# Answer')[1] + if '# Answer' in ground_truth: + ground_truth = ground_truth.split('# Answer')[1] + prediction = prediction.strip() + ground_truth = ground_truth.strip() + prediction = MathORM.extract_boxed_result(prediction) + ground_truth = MathORM.extract_boxed_result(ground_truth) + if self.use_opencompass: + reward = self.evaluator.is_equiv(prediction, ground_truth) + else: + reward = MathORM.compare_consecutive(prediction, ground_truth) + rewards.append(float(reward)) + return rewards + + +class MathAccuracy(ORM): + + def __init__(self): + import importlib.util + assert importlib.util.find_spec('math_verify') is not None, ( + "The math_verify package is required but not installed. 
Please install it using 'pip install math_verify'.") + + def __call__(self, completions, solution, **kwargs) -> List[float]: + from latex2sympy2_extended import NormalizationConfig + from math_verify import LatexExtractionConfig, parse, verify + rewards = [] + for content, sol in zip(completions, solution): + gold_parsed = parse(sol, extraction_mode='first_match') + if len(gold_parsed) != 0: + # We require the answer to be provided in correct latex (no malformed operators) + answer_parsed = parse( + content, + extraction_config=[ + LatexExtractionConfig( + normalization_config=NormalizationConfig( + nits=False, + malformed_operators=False, + basic_latex=True, + equations=True, + boxed=True, + units=True, + ), + # Ensures that boxed is tried first + boxed_match_priority=0, + try_extract_without_anchor=False, + ) + ], + extraction_mode='first_match', + ) + # edge case + try: + reward = float(verify(gold_parsed, answer_parsed)) + except Exception: + reward = 0.0 + else: + # If the gold solution is not parseable, we reward 0 to skip this example + reward = 0.0 + rewards.append(reward) + return rewards + + +class Format(ORM): + + def __call__(self, completions, **kwargs) -> List[float]: + """Reward function that checks if the completion has a specific format.""" + pattern = r'^<think>.*?</think>\s*<answer>.*?</answer>(?![\s\S])' + matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completions] + return [1.0 if match else 0.0 for match in matches] + + +class ReActFormat(ORM): + + def __call__(self, completions, **kwargs) -> List[float]: + """Reward function that checks if the completion has a specific format.""" + pattern = r'^<think>.*?</think>\s*Action:.*?Action Input:.*?$' + matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completions] + return [1.0 if match else 0.0 for match in matches] + + +class CosineReward(ORM): + # https://arxiv.org/abs/2502.03373 + def __init__(self, + tokenizer=None, + cosine_min_len_value_wrong: float = -0.5, + cosine_max_len_value_wrong: float = 0.0, + cosine_min_len_value_correct: float = 1.0, + cosine_max_len_value_correct: float = 0.5, + cosine_max_len: int = 1000, + accuracy_orm=None): + self.tokenizer = tokenizer + self.min_len_value_wrong = cosine_min_len_value_wrong + self.max_len_value_wrong = cosine_max_len_value_wrong + self.min_len_value_correct = cosine_min_len_value_correct + self.max_len_value_correct = cosine_max_len_value_correct + self.max_len = cosine_max_len + self.accuracy_orm = accuracy_orm or MathAccuracy() + + @staticmethod + def cosfn(t, T, min_value, max_value): + import math + return max_value - (max_value - min_value) * (1 - math.cos(t * math.pi / T)) / 2 + + def __call__(self, completions, solution, **kwargs) -> List[float]: + acc_rewards = self.accuracy_orm(completions, solution, **kwargs) + rewards = [] + for content, acc_reward in zip(completions, acc_rewards): + is_correct = acc_reward >= 1.
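+            # Cosine length schedule (see paper link above): cosfn moves from max_value at length 0 to
+            # min_value at cosine_max_len, so shorter correct answers earn more reward while longer
+            # incorrect answers are penalized less.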
+ if is_correct: + # Swap min/max for correct answers + min_value = self.max_len_value_correct + max_value = self.min_len_value_correct + else: + min_value = self.max_len_value_wrong + max_value = self.min_len_value_wrong + gen_len = len(self.tokenizer.encode(content)) + reward = self.cosfn(gen_len, self.max_len, min_value, max_value) + rewards.append(reward) + return rewards + + +class RepetitionPenalty(ORM): + # https://arxiv.org/abs/2502.03373 + def __init__(self, repetition_n_grams: int = 3, repetition_max_penalty: float = -1.0): + self.ngram_size = repetition_n_grams + self.max_penalty = repetition_max_penalty + + @staticmethod + def zipngram(text: str, ngram_size: int): + words = text.lower().split() + return zip(*[words[i:] for i in range(ngram_size)]) + + def __call__(self, completions, **kwargs) -> List[float]: + """ + reward function the penalizes repetitions + + Args: + completions: List of model completions + """ + rewards = [] + for completion in completions: + if completion == '': + rewards.append(0.0) + continue + if len(completion.split()) < self.ngram_size: + rewards.append(0.0) + continue + + ngrams = set() + total = 0 + for ng in self.zipngram(completion, self.ngram_size): + ngrams.add(ng) + total += 1 + + scaling = 1 - len(ngrams) / total + reward = scaling * self.max_penalty + rewards.append(reward) + return rewards + + +class SoftOverlong(ORM): + + def __init__(self, tokenizer, soft_max_length, soft_cache_length): + self.tokenizer = tokenizer + assert soft_cache_length < soft_max_length + self.soft_max_length = soft_max_length + self.soft_cache_length = soft_cache_length + + def __call__(self, completions, **kwargs) -> List[float]: + rewards = [] + for completion in completions: + completion_length = len(self.tokenizer.encode(completion)) + expected_len = self.soft_max_length - self.soft_cache_length + exceed_len = completion_length - expected_len + rewards.append(min(-exceed_len / self.soft_cache_length, 0)) + return rewards + + +orms = { + 'toolbench': ReactORM, + 'math': MathORM, + 'accuracy': MathAccuracy, + 'format': Format, + 'react_format': ReActFormat, + 'cosine': CosineReward, + 'repetition': RepetitionPenalty, + 'soft_overlong': SoftOverlong, +} diff --git a/swift/plugin/__init__.py b/swift/plugin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..109a4294314c7869d1b7e2cd7f1003c0c23aa50a --- /dev/null +++ b/swift/plugin/__init__.py @@ -0,0 +1,42 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+from typing import TYPE_CHECKING + +from swift.utils.import_utils import _LazyModule + +if TYPE_CHECKING: + from .callback import extra_callbacks + from .loss import LOSS_MAPPING, get_loss_func + from .loss_scale import loss_scale_map + from .metric import InferStats, MeanMetric, Metric, compute_acc, get_metric, compute_rouge_bleu + from .optimizer import optimizers_map + from .agent_template import agent_templates + from .tuner import Tuner, extra_tuners, PeftTuner + from .prm import prms, PRM + from .orm import orms, ORM + from .multi_turn import multi_turns + from .rm_plugin import rm_plugins + +else: + _import_structure = { + 'callback': ['extra_callbacks'], + 'loss': ['LOSS_MAPPING', 'get_loss_func'], + 'loss_scale': ['loss_scale_map'], + 'metric': ['InferStats', 'MeanMetric', 'Metric', 'compute_acc', 'get_metric', 'compute_rouge_bleu'], + 'optimizer': ['optimizers_map'], + 'agent_template': ['agent_templates'], + 'tuner': ['Tuner', 'extra_tuners', 'PeftTuner'], + 'prm': ['prms', 'PRM'], + 'orm': ['orms', 'ORM'], + 'multi_turn': ['multi_turns'], + 'rm_plugin': ['rm_plugins'] + } + + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/swift/plugin/__pycache__/__init__.cpython-310.pyc b/swift/plugin/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2123a2572389d63dcbdb9a05e8d57059c78bc193 Binary files /dev/null and b/swift/plugin/__pycache__/__init__.cpython-310.pyc differ diff --git a/swift/plugin/__pycache__/callback.cpython-310.pyc b/swift/plugin/__pycache__/callback.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70665a44fbee1b92aff1da83fddff0852e2668c2 Binary files /dev/null and b/swift/plugin/__pycache__/callback.cpython-310.pyc differ diff --git a/swift/plugin/__pycache__/loss.cpython-310.pyc b/swift/plugin/__pycache__/loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d378ffdd253ba23bd4a8ceb5c104aa081d892039 Binary files /dev/null and b/swift/plugin/__pycache__/loss.cpython-310.pyc differ diff --git a/swift/plugin/__pycache__/metric.cpython-310.pyc b/swift/plugin/__pycache__/metric.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29a3261fe758ba9517ba11d4af33f79a756655f1 Binary files /dev/null and b/swift/plugin/__pycache__/metric.cpython-310.pyc differ diff --git a/swift/plugin/__pycache__/multi_turn.cpython-310.pyc b/swift/plugin/__pycache__/multi_turn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a36ae7845ff09118f2bd4b6fd89c7fb3f7e327c Binary files /dev/null and b/swift/plugin/__pycache__/multi_turn.cpython-310.pyc differ diff --git a/swift/plugin/__pycache__/orm.cpython-310.pyc b/swift/plugin/__pycache__/orm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b3f4062cf5f4055e0089b050d27e905f7c11180 Binary files /dev/null and b/swift/plugin/__pycache__/orm.cpython-310.pyc differ diff --git a/swift/plugin/__pycache__/rm_plugin.cpython-310.pyc b/swift/plugin/__pycache__/rm_plugin.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5152cf4ea180259c1a70694b243bb2ad6567cbb Binary files /dev/null and b/swift/plugin/__pycache__/rm_plugin.cpython-310.pyc differ diff --git a/swift/plugin/__pycache__/tuner.cpython-310.pyc b/swift/plugin/__pycache__/tuner.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..503ec3dbcf4ebb7b69e54a2fb100adfbfcaf99e4 Binary files /dev/null and b/swift/plugin/__pycache__/tuner.cpython-310.pyc differ diff --git a/swift/plugin/agent_template/__init__.py b/swift/plugin/agent_template/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..35f40f9308aa70ad0a608cb3158fd0207578c5e9 --- /dev/null +++ b/swift/plugin/agent_template/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .base import BaseAgentTemplate +from .extra import ReactGRPOAgentTemplate +from .glm4 import GLM4_0414AgentTemplate, GLM4AgentTemplate +from .hermes import HermesAgentTemplate +from .llama import Llama3AgentTemplate, Llama4AgentTemplate +from .qwen import QwenEnAgentTemplate, QwenEnParallelAgentTemplate, QwenZhAgentTemplate, QwenZhParallelAgentTemplate +from .react import ReactEnAgentTemplate, ReactZnAgentTemplate +from .toolbench import ToolBenchAgentTemplate + +agent_templates = { + # ref: https://qwen.readthedocs.io/zh-cn/latest/framework/function_call.html#function-calling-templates + 'react_en': ReactEnAgentTemplate, + 'react_zh': ReactZnAgentTemplate, + # ref: https://github.com/QwenLM/Qwen-Agent/blob/main/qwen_agent/llm/fncall_prompts/qwen_fncall_prompt.py + 'qwen_en': QwenEnAgentTemplate, + 'qwen_zh': QwenZhAgentTemplate, + 'qwen_en_parallel': QwenEnParallelAgentTemplate, + 'qwen_zh_parallel': QwenZhParallelAgentTemplate, + 'hermes': HermesAgentTemplate, + 'toolbench': ToolBenchAgentTemplate, # ref: https://modelscope.cn/datasets/swift/ToolBench + 'glm4': GLM4AgentTemplate, + 'glm4_0414': GLM4_0414AgentTemplate, # ref: https://modelscope.cn/models/ZhipuAI/GLM-4-9B-0414 + 'llama3': Llama3AgentTemplate, + 'llama4': Llama4AgentTemplate, + # extra + 'react_grpo': ReactGRPOAgentTemplate +} diff --git a/swift/plugin/agent_template/__pycache__/__init__.cpython-310.pyc b/swift/plugin/agent_template/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af930134339c6933784f85e712983a25898e4be3 Binary files /dev/null and b/swift/plugin/agent_template/__pycache__/__init__.cpython-310.pyc differ diff --git a/swift/plugin/agent_template/__pycache__/base.cpython-310.pyc b/swift/plugin/agent_template/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee67e6eb32076e1e4275e5018ed28023f03ccfd5 Binary files /dev/null and b/swift/plugin/agent_template/__pycache__/base.cpython-310.pyc differ diff --git a/swift/plugin/agent_template/__pycache__/extra.cpython-310.pyc b/swift/plugin/agent_template/__pycache__/extra.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd322399df6d6fb7c4a7c0004f9a879533d8152f Binary files /dev/null and b/swift/plugin/agent_template/__pycache__/extra.cpython-310.pyc differ diff --git a/swift/plugin/agent_template/__pycache__/glm4.cpython-310.pyc b/swift/plugin/agent_template/__pycache__/glm4.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c7946c8feecfc80a526db7458bfa03da4a4c66e Binary files /dev/null and b/swift/plugin/agent_template/__pycache__/glm4.cpython-310.pyc differ diff --git a/swift/plugin/agent_template/__pycache__/hermes.cpython-310.pyc b/swift/plugin/agent_template/__pycache__/hermes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d49bd73d1e5974851d17f80a777796888b0ca9eb Binary files /dev/null and b/swift/plugin/agent_template/__pycache__/hermes.cpython-310.pyc 
differ diff --git a/swift/plugin/agent_template/__pycache__/llama.cpython-310.pyc b/swift/plugin/agent_template/__pycache__/llama.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b1e7a2e6064b5f01ec4bd3ba3763656f6019675 Binary files /dev/null and b/swift/plugin/agent_template/__pycache__/llama.cpython-310.pyc differ diff --git a/swift/plugin/agent_template/__pycache__/qwen.cpython-310.pyc b/swift/plugin/agent_template/__pycache__/qwen.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27ff8c9863c5dbb2266bcabcfe9096413b7f4d33 Binary files /dev/null and b/swift/plugin/agent_template/__pycache__/qwen.cpython-310.pyc differ diff --git a/swift/plugin/agent_template/__pycache__/react.cpython-310.pyc b/swift/plugin/agent_template/__pycache__/react.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5be026465f4edc8858101f432ab99856b39545e Binary files /dev/null and b/swift/plugin/agent_template/__pycache__/react.cpython-310.pyc differ diff --git a/swift/plugin/agent_template/__pycache__/toolbench.cpython-310.pyc b/swift/plugin/agent_template/__pycache__/toolbench.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5d5f9ebf4f4e833dfc13a95447546eedfa83af6 Binary files /dev/null and b/swift/plugin/agent_template/__pycache__/toolbench.cpython-310.pyc differ diff --git a/swift/plugin/agent_template/base.py b/swift/plugin/agent_template/base.py new file mode 100644 index 0000000000000000000000000000000000000000..a24fc9d49b804e0fa2eefe8f8f8803cf70a7ddaa --- /dev/null +++ b/swift/plugin/agent_template/base.py @@ -0,0 +1,158 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import ast +from abc import ABC, abstractmethod +from dataclasses import asdict, dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union + +import json + +if TYPE_CHECKING: + from swift.llm.infer import Function + from swift.llm.template import Prompt + + +@dataclass +class AgentKeyword: + action: str = 'Action:' + action_input: str = 'Action Input:' + observation: str = 'Observation:' + + +@dataclass +class ToolDesc: + name_for_model: str + name_for_human: str + description_for_model: str + parameters: str + args_format: str + + +class ReactCompatMixin: + keyword = AgentKeyword() + + @staticmethod + def _split_action_action_input(response: str, keyword: AgentKeyword) -> List['Function']: + from swift.llm.template import split_str_parts_by + from swift.llm.infer import Function + agent_parts = split_str_parts_by(response, list(asdict(keyword).values())) + functions = [] + action_content = None + + for part in agent_parts: + key, content = part['key'].lower(), part['content'] + if action_content is None and key == keyword.action.lower(): + action_content = content + elif action_content is not None and key == keyword.action_input.lower(): + functions.append(Function(name=action_content, arguments=content)) + action_content = None + + return functions + + def get_toolcall(self, response: str) -> List['Function']: + functions = self._split_action_action_input(response, self.keyword) + if len(functions) == 0 and self.keyword != ReactCompatMixin.keyword: + # compat react + functions = self._split_action_action_input(response, ReactCompatMixin.keyword) + return functions + + def _format_tool_responses( + self, + assistant_content: str, + tool_messages, + ) -> Tuple[str, 'Prompt']: + assert len(tool_messages) > 0 + with_action = self.keyword.action in 
assistant_content and self.keyword.action_input in assistant_content + if with_action: + if not assistant_content.endswith(self.keyword.observation): + if not assistant_content.endswith('\n'): + assistant_content += '\n' + assistant_content += self.keyword.observation + res = [] + for i, tool_message in enumerate(tool_messages): + if i > 0: + res.append(self.keyword.observation) + tool_content = tool_message['content'] + res.append(tool_content) + if not tool_content.endswith('\n'): + res.append('\n') + else: + res = [] + for tool_message in tool_messages: + res.append(tool_message['content']) + return assistant_content, res + + @staticmethod + def _parse_tool_call(content) -> Dict[str, Any]: + obj = BaseAgentTemplate._parse_json(content) + name = obj['name'] + arguments = obj.get('arguments') or obj.get('parameters') + arguments = BaseAgentTemplate._parse_json(arguments) + assert arguments is not None, f'content: {content}' + return {'name': name, 'arguments': arguments} + + def _format_tool_calls(self, tool_call_messages) -> str: + # -> assistant_content + tool_calls = [] + for message in tool_call_messages: + tool_call = self._parse_tool_call(message['content']) + tool_calls.append(f'{self.keyword.action} {tool_call["name"]}\n' + f'{self.keyword.action_input} {tool_call["arguments"]}\n') + tool_calls.append(self.keyword.observation) + return ''.join(tool_calls) + + +class BaseAgentTemplate(ReactCompatMixin, ABC): + + @staticmethod + def _get_tool_name(tool): + return tool.get('name_for_model') or tool.get('name') + + @staticmethod + def unwrap_tool(tool): + assert isinstance(tool, dict), f'tool: {tool}' + if 'type' in tool and 'function' in tool: + tool = tool['function'] + return tool + + @staticmethod + def wrap_tool(tool): + assert isinstance(tool, dict), f'tool: {tool}' + if 'type' not in tool and 'function' not in tool: + tool = {'type': 'function', 'function': tool} + return tool + + @staticmethod + def _parse_tool(tool, lang: Literal['zh', 'en']) -> ToolDesc: + tool = BaseAgentTemplate.unwrap_tool(tool) + name_for_model = BaseAgentTemplate._get_tool_name(tool) + name_for_human = tool.get('name_for_human') or name_for_model + + description = tool.get('description') or tool.get('description_for_model') + parameters = tool.get('parameters') or {} + parameters = parameters if isinstance(parameters, str) else json.dumps(parameters, ensure_ascii=False) + args_format = '此工具的输入应为JSON对象。' if lang == 'zh' else 'Format the arguments as a JSON object.' + tool_desc = ToolDesc( + name_for_model=name_for_model, + name_for_human=name_for_human, + description_for_model=description, + parameters=parameters, + args_format=args_format) + assert name_for_model is not None and description is not None, f'tool_desc: {tool_desc}' + return tool_desc + + @staticmethod + def _parse_json(json_str: str) -> Optional[Any]: + if not isinstance(json_str, str): + return json_str + try: + res = json.loads(json_str) + except json.JSONDecodeError: + try: + res = ast.literal_eval(json_str) + except Exception: + return + return res + + @abstractmethod + def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str: + pass diff --git a/swift/plugin/agent_template/extra.py b/swift/plugin/agent_template/extra.py new file mode 100644 index 0000000000000000000000000000000000000000..019f05a786c1a178a715c3a1522690351617c5fc --- /dev/null +++ b/swift/plugin/agent_template/extra.py @@ -0,0 +1,36 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
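+# ReACT-style agent template intended for GRPO training: the system prompt asks the model to put its
+# reasoning in <think> tags before emitting the Action / Action Input lines.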
+from typing import List, Union + +from .base import BaseAgentTemplate + + +class ReactGRPOAgentTemplate(BaseAgentTemplate): + + def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str: + tool_names = [] + tool_descs = [] + for tool in tools: + tool_desc = self._parse_tool(tool, 'en') + tool_names.append(tool_desc.name_for_model) + tool_descs.append( + f'{tool_desc.name_for_model}: Call this tool to interact with the {tool_desc.name_for_human} API. ' + f'What is the {tool_desc.name_for_human} API useful for? {tool_desc.description_for_model} ' + f'Parameters: {tool_desc.parameters} {tool_desc.args_format}') + + return """A conversation for tool calling between User and Assistant. The user asks a question which may be solved by calling tools, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process should be enclosed within <think> </think> tags and answer should follow the ReACT format(Action:xxx\nAction Input:xxx), i.e., <think> reasoning process here </think> Action: action here\nAction Input: parameters here + +Answer the following questions as best as you can. You have access to the following tools: + +""" + '\n\n'.join(tool_descs) + f""" + +Use the following format: + +you should always think about what to do +Action: the action to take, should be one of [{','.join(tool_names)}] +Action Input: the input to the action +Observation: the result of the action, given by the actual calling +... (this Thought/Action/Action Input/Observation can be repeated zero or more times) +Final Answer: the final answer to the original input question + +Begin! +""" # noqa diff --git a/swift/plugin/agent_template/glm4.py b/swift/plugin/agent_template/glm4.py new file mode 100644 index 0000000000000000000000000000000000000000..0dfea2ab651d316085e042b53765ab71e562bf9f --- /dev/null +++ b/swift/plugin/agent_template/glm4.py @@ -0,0 +1,79 @@ +# Copyright (c) Alibaba, Inc. and its affiliates.
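+# GLM-4 function calling: each call is emitted as `name\n{json arguments}`, with <|assistant|> and
+# <|observation|> special tokens separating successive calls and tool results.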
+import re +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import json + +from .base import BaseAgentTemplate + +if TYPE_CHECKING: + from swift.llm.infer import Function + from swift.llm.template import Prompt + + +class GLM4AgentTemplate(BaseAgentTemplate): + is_glm4_0414 = False + + @staticmethod + def _find_function_call(single_content: str) -> Optional['Function']: + from swift.llm.infer import Function + single_content = single_content.replace('<|observation|>', '') + pattern = re.compile(r'([^\n`]*?)\n({.*?})(?=\w*\n|$)', re.DOTALL) + matches = pattern.findall(single_content) + if not matches: + return + + name, arguments = matches[0] + return Function(name=name, arguments=arguments) + + def get_toolcall(self, response: str) -> List['Function']: + toolcall_list = response.split('<|assistant|>') + functions = [] + for toolcall in toolcall_list: + function = self._find_function_call(toolcall) + if function: + functions.append(function) + if len(functions) == 0: + # compat react_en + return super().get_toolcall(response) + return functions + + def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str: + tool_descs = [] + for tool in tools: + tool = self.unwrap_tool(tool) + name = self._get_tool_name(tool) + tool_descs.append(f'## {name}\n\n{json.dumps(tool, ensure_ascii=False, indent=4)}\n' + '在调用上述函数时,请使用 Json 格式表示调用的参数。') + glm4_system = '你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n' # noqa + return ('' if self.is_glm4_0414 else glm4_system) + """# 可用工具 + +""" + '\n'.join(tool_descs) + + def _format_tool_responses( + self, + assistant_content: str, + tool_messages, + ) -> Tuple[str, 'Prompt']: + with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content + if with_action: + return super()._format_tool_responses(assistant_content, tool_messages) + res = ['\n'] + for i, tool_message in enumerate(tool_messages): + tool_content = tool_message['content'] + if i > 0: + res.append('<|observation|>\n') + res.append(tool_content) + res.append('<|assistant|>\n') + return assistant_content, res + + def _format_tool_calls(self, tool_call_messages) -> str: + tool_calls = [] + for message in tool_call_messages: + tool_call = self._parse_tool_call(message['content']) + tool_calls.append(f'{tool_call["name"]}\n{tool_call["arguments"]}') + return '<|assistant|>'.join(tool_calls) + '<|observation|>' + + +class GLM4_0414AgentTemplate(GLM4AgentTemplate): + is_glm4_0414 = True diff --git a/swift/plugin/agent_template/hermes.py b/swift/plugin/agent_template/hermes.py new file mode 100644 index 0000000000000000000000000000000000000000..28ab23fa3d803a1f62b209cffcd168a361512483 --- /dev/null +++ b/swift/plugin/agent_template/hermes.py @@ -0,0 +1,78 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
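+# Hermes-style function calling (as used by Qwen): available tools are listed inside <tools></tools> in the
+# system prompt, and each call is returned as a JSON object wrapped in <tool_call></tool_call> tags.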
+import re
+from typing import TYPE_CHECKING, List, Tuple, Union
+
+import json
+
+from .base import BaseAgentTemplate
+
+if TYPE_CHECKING:
+    from swift.llm.infer import Function
+    from swift.llm.template import Prompt
+
+
+class HermesAgentTemplate(BaseAgentTemplate):
+
+    def get_toolcall(self, response: str) -> List['Function']:
+        from swift.llm.infer import Function
+        res_list = re.findall(r'<tool_call>(.+?)</tool_call>', response, re.DOTALL)
+        functions = []
+        for res in res_list:
+            res = self._parse_json(res)
+            if isinstance(res, dict) and 'name' in res and 'arguments' in res:
+                functions.append(Function(name=res['name'], arguments=res['arguments']))
+        if len(functions) == 0:
+            # compat react_en
+            return super().get_toolcall(response)
+        return functions
+
+    def _format_tool_responses(
+        self,
+        assistant_content: str,
+        tool_messages,
+    ) -> Tuple[str, 'Prompt']:
+        with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content
+        if with_action:
+            return super()._format_tool_responses(assistant_content, tool_messages)
+        if hasattr(self, 'template_meta'):
+            prompt = self.template_meta.prompt
+            chat_sep = self.template_meta.chat_sep
+        else:
+            prompt = ['<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n']
+            chat_sep = ['<|im_end|>\n']
+        res = chat_sep.copy()
+        res_tool = []
+        for tool_message in tool_messages:
+            tool_content = tool_message['content']
+            res_tool.append(f'<tool_response>\n{tool_content}\n</tool_response>')
+        total_tool = '\n'.join(res_tool)
+        for context in prompt:
+            if isinstance(context, str):
+                context = context.replace('{{QUERY}}', total_tool)
+            res.append(context)
+        return assistant_content, res
+
+    def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
+        tool_descs = [json.dumps(self.wrap_tool(tool), ensure_ascii=False) for tool in tools]
+        return f"""{system}
+
+# Tools
+
+You may call one or more functions to assist with the user query.
+
+You are provided with function signatures within <tools></tools> XML tags:
+<tools>
+""" + '\n'.join(tool_descs) + """
+</tools>
+
+For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
+<tool_call>
+{"name": <function-name>, "arguments": <args-json-object>}
+</tool_call>"""
+
+    def _format_tool_calls(self, tool_call_messages):
+        tool_calls = []
+        for message in tool_call_messages:
+            tool_call = self._parse_tool_call(message['content'])
+            tool_calls.append(f'<tool_call>\n{json.dumps(tool_call, ensure_ascii=False)}\n</tool_call>')
+        return '\n'.join(tool_calls)
diff --git a/swift/plugin/agent_template/llama.py b/swift/plugin/agent_template/llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..a247d8420a13d11ad68fbd97bc669d20741edd87
--- /dev/null
+++ b/swift/plugin/agent_template/llama.py
@@ -0,0 +1,78 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
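+# Illustrative sketch of the Llama-3 style tool call parsed below; the function
+# name and parameters are hypothetical:
+#
+#     {"name": "get_weather", "parameters": {"city": "Beijing"}}<|eom_id|>
+#
+# get_toolcall strips a trailing <|eom_id|> token and matches bare JSON objects
+# with "name"/"parameters" keys; tool results are appended as
+# <|start_header_id|>tool<|end_header_id|> turns.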
+import re +from typing import TYPE_CHECKING, List, Tuple, Union + +import json + +from .base import BaseAgentTemplate + +if TYPE_CHECKING: + from swift.llm.infer import Function + from swift.llm.template import Prompt + + +class Llama3AgentTemplate(BaseAgentTemplate): + eom_token = '<|eom_id|>' + start_token = '<|start_header_id|>' + end_token = '<|end_header_id|>' + eot_token = '<|eot_id|>' + + def get_toolcall(self, response: str) -> List['Function']: + from swift.llm.infer import Function + if response.endswith(self.eom_token): + response = response[:-len(self.eom_token)] + functions = [] + res_list = re.findall(r'{[^{]*?"name":.*?"parameters":\s*?{.*?}\s*?}', response, re.DOTALL) + for res in res_list: + res = self._parse_json(res) + if isinstance(res, dict) and 'name' in res and 'parameters' in res: + functions.append(Function(name=res['name'], arguments=res['parameters'])) + if len(functions) == 0: + # compat react_en + return super().get_toolcall(response) + return functions + + def _format_tool_responses( + self, + assistant_content: str, + tool_messages, + ) -> Tuple[str, 'Prompt']: + with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content + if with_action: + return super()._format_tool_responses(assistant_content, tool_messages) + res = [self.eot_token] + for tool_message in tool_messages: + tool_content = tool_message['content'] + res.append(f'{self.start_token}tool{self.end_token}\n\n{tool_content}{self.eot_token}') + res.append(f'{self.start_token}assistant{self.end_token}\n\n') + return assistant_content, res + + def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str: + assert user_message is not None + user_content = user_message['content'] + tool_descs = [json.dumps(tool, ensure_ascii=False, indent=4) for tool in tools] + new_user_content = """Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. + +Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables. + +""" + '\n\n'.join(tool_descs) + f""" + +{user_content}""" # noqa + user_message['content'] = new_user_content + return system + + def _format_tool_calls(self, tool_call_messages) -> str: + tool_calls = [] + for message in tool_call_messages: + tool_call = self._parse_tool_call(message['content']) + tool_call['parameters'] = tool_call.pop('arguments') + tool_calls.append(json.dumps(tool_call, ensure_ascii=False)) + return '\n'.join(tool_calls) + + +class Llama4AgentTemplate(Llama3AgentTemplate): + eom_token = '<|eom|>' + start_token = '<|header_start|>' + end_token = '<|header_end|>' + eot_token = '<|eot|>' + toolcall_pattern = r'(.+?)<\|eom\|>' diff --git a/swift/plugin/agent_template/qwen.py b/swift/plugin/agent_template/qwen.py new file mode 100644 index 0000000000000000000000000000000000000000..6443a12d44e9e705ca5ea6a0fbe248bc093a2c21 --- /dev/null +++ b/swift/plugin/agent_template/qwen.py @@ -0,0 +1,130 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
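+# Illustrative sketch of the Qwen agent markers used below (tool name, arguments
+# and result are hypothetical):
+#
+#     ✿FUNCTION✿: get_weather
+#     ✿ARGS✿: {"city": "Beijing"}
+#     ✿RESULT✿: {"temperature": 25}
+#     ✿RETURN✿: It is 25°C in Beijing.
+#
+# The AgentKeyword instance below tells BaseAgentTemplate which markers delimit the
+# action, the action input and the observation.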
+from typing import List, Union + +from .base import AgentKeyword, BaseAgentTemplate + +keyword = AgentKeyword( + action='✿FUNCTION✿:', + action_input='✿ARGS✿:', + observation='✿RESULT✿:', +) + + +class QwenEnAgentTemplate(BaseAgentTemplate): + keyword = keyword + + def _get_tool_names_descs(self, tools): + tool_names = [] + tool_descs = [] + for tool in tools: + tool_desc = self._parse_tool(tool, 'en') + tool_names.append(tool_desc.name_for_model) + tool_descs.append(f'### {tool_desc.name_for_human}\n\n' + f'{tool_desc.name_for_model}: {tool_desc.description_for_model} ' + f'Parameters: {tool_desc.parameters} {tool_desc.args_format}') + return tool_names, tool_descs + + def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str: + tool_names, tool_descs = self._get_tool_names_descs(tools) + return f"""{system} + +# Tools + +## You have access to the following tools: + +""" + '\n\n'.join(tool_descs) + f""" + +## When you need to call a tool, please insert the following command in your reply, which can be called zero or multiple times according to your needs: + +✿FUNCTION✿: The tool to use, should be one of [{','.join(tool_names)}] +✿ARGS✿: The input of the tool +✿RESULT✿: Tool results +✿RETURN✿: Reply based on tool results. Images need to be rendered as ![](url)""" # noqa + + +class QwenZhAgentTemplate(BaseAgentTemplate): + keyword = keyword + + def _get_tool_names_descs(self, tools): + tool_names = [] + tool_descs = [] + for tool in tools: + tool_desc = self._parse_tool(tool, 'zh') + tool_names.append(tool_desc.name_for_model) + tool_descs.append(f'### {tool_desc.name_for_human}\n\n' + f'{tool_desc.name_for_model}: {tool_desc.description_for_model} ' + f'输入参数:{tool_desc.parameters} {tool_desc.args_format}') + return tool_names, tool_descs + + def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str: + tool_names, tool_descs = self._get_tool_names_descs(tools) + return f"""{system} + +# 工具 + +## 你拥有如下工具: + +""" + '\n\n'.join(tool_descs) + f""" + +## 你可以在回复中插入零次、一次或多次以下命令以调用工具: + +✿FUNCTION✿: 工具名称,必须是[{','.join(tool_names)}]之一。 +✿ARGS✿: 工具输入 +✿RESULT✿: 工具结果 +✿RETURN✿: 根据工具结果进行回复,需将图片用![](url)渲染出来""" # noqa + + +class QwenEnParallelAgentTemplate(QwenEnAgentTemplate): + + def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str: + tool_names, tool_descs = self._get_tool_names_descs(tools) + return f"""{system} + +# Tools + +## You have access to the following tools: + +""" + '\n\n'.join(tool_descs) + f""" + +## Insert the following command in your reply when you need to call N tools in parallel: + +✿FUNCTION✿: The name of tool 1, should be one of [{','.join(tool_names)}] +✿ARGS✿: The input of tool 1 +✿FUNCTION✿: The name of tool 2 +✿ARGS✿: The input of tool 2 +... +✿FUNCTION✿: The name of tool N +✿ARGS✿: The input of tool N +✿RESULT✿: The result of tool 1 +✿RESULT✿: The result of tool 2 +... +✿RESULT✿: he result of tool N +✿RETURN✿: Reply based on tool results. Images need to be rendered as ![](url)""" # noqa + + +class QwenZhParallelAgentTemplate(QwenZhAgentTemplate): + + def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str: + tool_names, tool_descs = self._get_tool_names_descs(tools) + return f"""{system} + +# 工具 + +## 你拥有如下工具: + +""" + '\n\n'.join(tool_descs) + f""" + +## 你可以在回复中插入以下命令以并行调用N个工具: + +✿FUNCTION✿: 工具1的名称,必须是[{','.join(tool_names)}]之一 +✿ARGS✿: 工具1的输入 +✿FUNCTION✿: 工具2的名称 +✿ARGS✿: 工具2的输入 +... 
+✿FUNCTION✿: 工具N的名称 +✿ARGS✿: 工具N的输入 +✿RESULT✿: 工具1的结果 +✿RESULT✿: 工具2的结果 +... +✿RESULT✿: 工具N的结果 +✿RETURN✿: 根据工具结果进行回复,需将图片用![](url)渲染出来""" # noqa diff --git a/swift/plugin/agent_template/react.py b/swift/plugin/agent_template/react.py new file mode 100644 index 0000000000000000000000000000000000000000..6bfa5b820c611f9651890e13705e37a3be3e0933 --- /dev/null +++ b/swift/plugin/agent_template/react.py @@ -0,0 +1,66 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import List, Union + +from .base import BaseAgentTemplate + + +class ReactEnAgentTemplate(BaseAgentTemplate): + + def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str: + tool_names = [] + tool_descs = [] + for tool in tools: + tool_desc = self._parse_tool(tool, 'en') + tool_names.append(tool_desc.name_for_model) + tool_descs.append( + f'{tool_desc.name_for_model}: Call this tool to interact with the {tool_desc.name_for_human} API. ' + f'What is the {tool_desc.name_for_human} API useful for? {tool_desc.description_for_model} ' + f'Parameters: {tool_desc.parameters} {tool_desc.args_format}') + + return """Answer the following questions as best you can. You have access to the following tools: + +""" + '\n\n'.join(tool_descs) + f""" + +Use the following format: + +Question: the input question you must answer +Thought: you should always think about what to do +Action: the action to take, should be one of [{','.join(tool_names)}] +Action Input: the input to the action +Observation: the result of the action +... (this Thought/Action/Action Input/Observation can be repeated zero or more times) +Thought: I now know the final answer +Final Answer: the final answer to the original input question + +Begin! +""" + + +class ReactZnAgentTemplate(BaseAgentTemplate): + + def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str: + tool_names = [] + tool_descs = [] + for tool in tools: + tool_desc = self._parse_tool(tool, 'zh') + tool_names.append(tool_desc.name_for_model) + tool_descs.append(f'{tool_desc.name_for_model}: 调用此工具与 {tool_desc.name_for_human} API 进行交互。' + f'{tool_desc.name_for_human} 有什么用?{tool_desc.description_for_model} ' + f'输入参数:{tool_desc.parameters} {tool_desc.args_format}') + return """尽可能地回答以下问题。你可以使用以下工具: + +""" + '\n\n'.join(tool_descs) + f""" + +请按照以下格式进行: + +Question: 需要你回答的输入问题 +Thought: 你应该总是思考该做什么 +Action: 需要使用的工具,应该是[{','.join(tool_names)}]中的一个 +Action Input: 传入工具的内容 +Observation: 行动的结果 +... (这个Thought/Action/Action Input/Observation可以重复N次) +Thought: 我现在知道最后的答案 +Final Answer: 对原始输入问题的最终答案 + +现在开始! +""" diff --git a/swift/plugin/agent_template/toolbench.py b/swift/plugin/agent_template/toolbench.py new file mode 100644 index 0000000000000000000000000000000000000000..54404e9f8e9faa75e5b9ecac1110d371e49318ba --- /dev/null +++ b/swift/plugin/agent_template/toolbench.py @@ -0,0 +1,39 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import List, Union + +import json + +from .base import BaseAgentTemplate + + +class ToolBenchAgentTemplate(BaseAgentTemplate): + + def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str: + for i, tool in enumerate(tools): + tools[i] = self.unwrap_tool(tool) + tools = json.dumps(tools, ensure_ascii=False) + return f"""You can use many tools(functions) to do the following task. +First I will give you the task description, and your task start. 
+At each step, you need to give your thought to analyze the status now and what to do next, \ +with a function call to actually execute your step. Your output should follow this format: +Thought: +Action: +Action Input: + +After the call, you will get the call result, and you are now in a new state. +Then you will analyze your status now, then decide what to do next... +After many (Thought-call) pairs, you finally perform the task, then you can give your final answer. +Remember: +1.the state change is irreversible, you can't go back to one of the former state, if you want to restart the task, \ +say \"I give up and restart\". +2.All the thought is short, at most in 5 sentence. +3.You can do more then one try, so if your plan is to continuously try some conditions, \ +you can do one of the conditions per try. +Let's Begin! +Task description: You should use functions to help handle the real time user queries. Remember: +1.ALWAYS call \"Finish\" function at the end of the task. And the final answer should contain enough information \ +to show to the user,If you can't handle the task, \ +or you find that function calls always fail(the function is not valid now), \ +use function Finish->give_up_and_restart. +2.Do not use origin tool names, use only subfunctions' names. +Specifically, you have access to the following APIs: {tools}""" diff --git a/swift/plugin/callback.py b/swift/plugin/callback.py new file mode 100644 index 0000000000000000000000000000000000000000..01db43c9b014ae33d02e43bd6d3ee30eadbbeda5 --- /dev/null +++ b/swift/plugin/callback.py @@ -0,0 +1,32 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import numpy as np +from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments + +from swift.utils import get_logger + +logger = get_logger() + + +class EarlyStopCallback(TrainerCallback): + """An early stop implementation""" + + def __init__(self, total_interval=3): + self.best_metric = None + self.interval = 0 + self.total_interval = total_interval + + def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + operator = np.greater if args.greater_is_better else np.less + if self.best_metric is None or operator(state.best_metric, self.best_metric): + self.best_metric = state.best_metric + else: + self.interval += 1 + + if self.interval >= self.total_interval: + logger.info(f'Training stop because of eval metric is stable at step {state.global_step}') + control.should_training_stop = True + + +extra_callbacks = [] +# This example shows a simple example of EarlyStop Callback, uncomment this to use +# extra_callbacks = [EarlyStopCallback()] diff --git a/swift/plugin/loss.py b/swift/plugin/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..6ad82a5deef5b373e4a55eddeb1d136b65f13b06 --- /dev/null +++ b/swift/plugin/loss.py @@ -0,0 +1,388 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
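+# Loss plugins are selected by name at training time (--loss_type xxx looks the
+# function up in LOSS_MAPPING). A minimal custom loss could be registered as in the
+# following sketch (the name 'my_loss' is hypothetical):
+#
+#     @register_loss_func('my_loss')
+#     def my_loss(outputs, labels, loss_scale=None, num_items_in_batch=None):
+#         loss, _ = ce_loss_func(outputs, labels)
+#         if num_items_in_batch is None:
+#             return loss.mean()
+#         return loss.sum() / num_items_in_batch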
+import os +from enum import Enum +from typing import Callable, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from accelerate.utils import gather_object +from torch import nn +from torch.nn import CrossEntropyLoss, MSELoss +from transformers.utils import strtobool + + +class LossType: + loss_scale = 'loss_scale' + cosine_similarity = 'cosine_similarity' + contrastive = 'contrastive' + online_contrastive = 'online_contrastive' + infonce = 'infonce' + + +LOSS_MAPPING = {} + + +def register_loss_func(loss_type: str, loss_func: Optional[Callable] = None): + loss_info = {} + + if loss_func is not None: + loss_info['loss_func'] = loss_func + LOSS_MAPPING[loss_type] = loss_info + return + + def _register_loss_func(loss_func: Callable) -> Callable: + loss_info['loss_func'] = loss_func + LOSS_MAPPING[loss_type] = loss_info + return loss_func + + return _register_loss_func + + +def ce_loss_func(outputs, labels): + logits = outputs.logits + device = logits.device + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :] + shift_labels = labels[..., 1:].to(device) + # Save memory + masks = shift_labels != -100 + shift_logits = shift_logits[masks] + shift_labels = shift_labels[masks] + # Flatten the tokens + loss_fct = CrossEntropyLoss(reduction='none') + loss = loss_fct(shift_logits, shift_labels) + return loss, masks + + +# Use @register_loss_func to decorate your own loss, use --loss_type xxx to train +@register_loss_func(LossType.loss_scale) +def loss_scale_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor: + """Loss func + + Args: + outputs: The model outputs + labels: The labels + loss_scale: The loss scale + num_items_in_batch: Number of tokens in the labels of gradient accumulation round that are not -100. 
+ + Returns: + + """ + loss, masks = ce_loss_func(outputs, labels) + if loss_scale is not None: + shift_scale = loss_scale[..., 1:].to(masks.device) + shift_scale = shift_scale[masks] + loss = (shift_scale * loss) + if num_items_in_batch is None: + loss = loss.mean() + else: + # compat transformers>=4.46 + loss = loss.sum() / num_items_in_batch + return loss + + +def _parse_pair_sentence(outputs): + if isinstance(outputs, dict): + last_hidden_state = outputs['last_hidden_state'] + else: + last_hidden_state = outputs + batch_size = last_hidden_state.shape[0] + shape_len = len(last_hidden_state.shape) + first_sentence = list(range(0, batch_size, 2)) + second_sentence = list(range(1, batch_size, 2)) + if shape_len == 3: + sentence1 = last_hidden_state[first_sentence][:, 0].squeeze(dim=1) + sentence2 = last_hidden_state[second_sentence][:, 0].squeeze(dim=1) + else: + sentence1 = last_hidden_state[first_sentence] + sentence2 = last_hidden_state[second_sentence] + return sentence1, sentence2 + + +# Code borrowed from sentence_transformers +class SiameseDistanceMetric(Enum): + """The metric for the contrastive loss""" + + EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2) # noqa + MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1) # noqa + COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y) # noqa + + +@register_loss_func(LossType.cosine_similarity) +def cosine_similarity_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor: + cos_score_transformation = nn.Identity() + loss_fct = MSELoss() + sentence1, sentence2 = _parse_pair_sentence(outputs) + output = cos_score_transformation(torch.cosine_similarity(sentence1, sentence2)) + return loss_fct(output, labels.to(output.dtype).view(-1)) + + +@register_loss_func(LossType.contrastive) +def contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor: + sentence1, sentence2 = _parse_pair_sentence(outputs) + distance_metric = SiameseDistanceMetric.COSINE_DISTANCE + distances = distance_metric(sentence1, sentence2) + margin = 0.5 + labels = labels.to(sentence1.dtype) + losses = 0.5 * (labels * distances.pow(2) + (1 - labels) * F.relu(margin - distances).pow(2)) + return losses.mean() + + +def calculate_paired_metrics(embeddings, labels): + from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, \ + paired_manhattan_distances + from scipy.stats import pearsonr, spearmanr + + embeddings1, embeddings2 = _parse_pair_sentence(embeddings) + cosine_scores = 1 - (paired_cosine_distances(embeddings1, embeddings2)) + manhattan_distances = -paired_manhattan_distances(embeddings1, embeddings2) + euclidean_distances = -paired_euclidean_distances(embeddings1, embeddings2) + dot_products = [np.dot(emb1, emb2) for emb1, emb2 in zip(embeddings1, embeddings2)] + + eval_pearson_cosine, _ = pearsonr(labels, cosine_scores) + eval_spearman_cosine, _ = spearmanr(labels, cosine_scores) + + eval_pearson_manhattan, _ = pearsonr(labels, manhattan_distances) + eval_spearman_manhattan, _ = spearmanr(labels, manhattan_distances) + + eval_pearson_euclidean, _ = pearsonr(labels, euclidean_distances) + eval_spearman_euclidean, _ = spearmanr(labels, euclidean_distances) + + eval_pearson_dot, _ = pearsonr(labels, dot_products) + eval_spearman_dot, _ = spearmanr(labels, dot_products) + + return { + 'pearson_cosine': eval_pearson_cosine, + 'pearson_euclidean': eval_pearson_manhattan, + 'pearson_manhattan': eval_pearson_euclidean, + 'pearson_dot_product': eval_pearson_dot, + 
'spearman_cosine': eval_spearman_cosine, + 'spearman_euclidean': eval_spearman_manhattan, + 'spearman_manhattan': eval_spearman_euclidean, + 'spearman_dot_product': eval_spearman_dot, + } + + +def calculate_infonce_metrics(embeddings, labels): + from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, \ + paired_manhattan_distances + from scipy.stats import pearsonr, spearmanr + hard_negatives = os.environ.get('INFONCE_HARD_NEGATIVES', None) + use_batch = strtobool(os.environ.get('INFONCE_USE_BATCH', 'True')) + split_tensors = _parse_multi_negative_sentences(torch.tensor(embeddings), torch.tensor(labels), hard_negatives) + split_tensors = [t.numpy() for t in split_tensors] + can_batched = hard_negatives is not None + if hard_negatives is None and len(set([s.shape[0] for s in split_tensors])) == 1: + can_batched = True + all_similarity_matrix = [] + all_labels = [] + pos_neg_margins = [] + if not use_batch: + if can_batched: + sentences = np.stack(split_tensors, axis=0) + similarity_matrix = np.matmul(sentences[:, 0:1], sentences[:, 1:].transpose((0, 2, 1))).squeeze(1) + all_similarity_matrix.append(similarity_matrix) + labels = np.zeros_like(similarity_matrix) + labels[:, 0] = 1 + all_labels.append(labels) + else: + for tensor in split_tensors: + similarity_matrix = np.matmul(tensor[0], tensor[1:].T) + all_similarity_matrix.append(similarity_matrix) + labels = np.zeros_like(similarity_matrix) + labels[0] = 1 + all_labels.append(labels) + max_neg_scores = np.max(similarity_matrix[labels == 0], axis=-1) + pos_neg_margins.append(np.mean(similarity_matrix[labels == 1] - max_neg_scores).item()) + else: + if can_batched: + sentences = np.stack(split_tensors, axis=0) + similarity_matrix = np.matmul(sentences[:, 0], sentences[:, 1:].reshape(-1, sentences.shape[2]).T) + all_similarity_matrix.append(similarity_matrix) + labels = np.zeros_like(similarity_matrix) + for row, col in enumerate(range(0, sentences.shape[0] * (sentences.shape[1] - 1), sentences.shape[1] - 1)): + labels[row, col] = 1 + all_labels.append(labels) + else: + all_tensors = [] + for tensor in split_tensors: + all_tensors.append(tensor[1:]) + sentences = np.concatenate(all_tensors, axis=0) + length = 0 + for idx, tensor in enumerate(split_tensors): + similarity_matrix = np.matmul(tensor[0], sentences.T) + all_similarity_matrix.append(similarity_matrix) + labels = np.zeros_like(similarity_matrix) + labels[length] = 1 + all_labels.append(labels) + length += tensor.shape[0] - 1 + max_neg_scores = np.max(similarity_matrix[labels == 0], axis=-1) + pos_neg_margins.append(np.mean(similarity_matrix[labels == 1] - max_neg_scores).item()) + + similarity_matrix = np.concatenate(all_similarity_matrix, axis=0) + labels = np.concatenate(all_labels, axis=0) + if can_batched: + pos_scores = similarity_matrix[labels == 1].reshape(similarity_matrix.shape[0], -1) + neg_scores = similarity_matrix[labels == 0].reshape(similarity_matrix.shape[0], -1) + max_neg_scores = np.max(neg_scores, axis=-1) + pos_neg_margin = np.mean(pos_scores - max_neg_scores).item() + else: + pos_scores = similarity_matrix[labels == 1] + neg_scores = similarity_matrix[labels == 0] + pos_neg_margin = np.mean(pos_neg_margins) + + mean_neg = np.mean(neg_scores) + mean_pos = np.mean(pos_scores) + return {'margin': pos_neg_margin, 'mean_neg': mean_neg, 'mean_pos': mean_pos} + + +def _parse_multi_negative_sentences(sentences, labels, hard_negatives=None): + split_indices = torch.nonzero(labels, as_tuple=False).squeeze().tolist() + if 
isinstance(split_indices, int): + split_indices = [split_indices] + split_indices.append(len(labels)) + split_indices = np.array(split_indices) + np.array(list(range(len(split_indices)))) + split_tensors = [] + + for i in range(len(split_indices) - 1): + start = split_indices[i] + end = split_indices[i + 1] + split_part = sentences[start:end] + if hard_negatives is not None: + negatives = len(split_part) - 2 + assert negatives > 0 + if negatives > hard_negatives: + split_part = split_part[:hard_negatives + 2] + elif negatives < hard_negatives: + selected = np.random.choice(list(range(negatives)), size=hard_negatives - negatives, replace=True) + selected += 1 # skip positive + split_part = torch.cat((split_part, split_part[selected]), dim=0) + split_tensors.append(split_part) + return split_tensors + + +@register_loss_func(LossType.infonce) +def infonce_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor: + temperature = float(os.environ.get('INFONCE_TEMPERATURE', '0.01')) # temperature + # calculate CE across the batch, meaning all samples will be negative except the matching positive + use_batch = strtobool(os.environ.get('INFONCE_USE_BATCH', 'True')) + hard_negatives = os.environ.get('INFONCE_HARD_NEGATIVES', None) # how many negative prompts kept in one sample + # mask out fake negatives + infonce_mask_fake_negative = strtobool(os.environ.get('INFONCE_MASK_FAKE_NEGATIVE', 'False')) + if hard_negatives is not None: + hard_negatives = int(hard_negatives) + from swift.utils import get_dist_setting + rank, _, world_size, _ = get_dist_setting() + # repeat of anchor(1)+positive(1)+negatives(n) + sentences = outputs['last_hidden_state'] + + if world_size > 1 and use_batch: + # gather all the sentences and labels across the gpus when calculate loss across all batches of all gpus + all_sentences = gather_object(sentences.unsqueeze(0)) + labels = gather_object(labels) + # override the gathered one + all_sentences[rank] = sentences + for idx in range(len(all_sentences)): + if idx == rank: + continue + # we don't calculate grad from other gpus + all_sentences[idx] = all_sentences[idx].detach().to(sentences.device) + sentences = torch.cat(all_sentences, dim=0) + labels = [tensor.to(sentences.device) for tensor in labels] + labels = torch.stack(labels, dim=0) + + # split tensors into single sample + # for example: batch_size=2 with tensor anchor(1)+positive(1)+negatives(3) + anchor(1)+positive(1)+negatives(2) + # labels will be [1,0,0,0,1,0,0], meaning 1 positive, 3 negatives, 1 positive, 2 negatives + split_tensors = _parse_multi_negative_sentences(sentences, labels, hard_negatives) + loss = 0 + can_batched = hard_negatives is not None + if hard_negatives is None and len(set([s.shape[0] for s in split_tensors])) == 1: + # all tensors have the same batch size + can_batched = True + if not use_batch: + # only calculate loss inside one sample + if can_batched: + # negative numbers are equal + # [B, neg+2, D] + sentences = torch.stack(split_tensors, dim=0) + # [B, 1, D] * [B, neg+1, D] + similarity_matrix = torch.matmul(sentences[:, 0:1], sentences[:, 1:].transpose(1, 2)) / temperature + # The positive one is the first element + labels = torch.zeros(len(split_tensors), dtype=torch.int64).to(sentences.device) + loss = nn.CrossEntropyLoss()(similarity_matrix.squeeze(1), labels) + else: + # the negative numbers may be different, use for loop + for tensor in split_tensors: + # [D] * [neg+1, D] + similarity_matrix = torch.matmul(tensor[0], tensor[1:].T) / temperature + # The 
positive one is the first element + labels = torch.tensor(0).to(tensor.device) + loss += nn.CrossEntropyLoss()(similarity_matrix, labels) + # avg between all batches in one gpu + loss /= len(split_tensors) + else: + + def mask_fake_negative(sim_matrix, sim_labels): + thresholds = sim_matrix[torch.arange(sim_matrix.size(0)), sim_labels].view(-1, 1) + 0.1 + thresholds = thresholds.detach() + mask = sim_matrix > thresholds + sim_matrix[mask] = float('-inf') + + if can_batched: + # [B, neg+2, D] + sentences = torch.stack(split_tensors, dim=0) + # [B, D] * [B*(neg+1), D] + similarity_matrix = torch.matmul(sentences[:, 0].squeeze(1), sentences[:, + 1:].reshape(-1, sentences.size(2)).T) + labels = torch.tensor(range(0, + sentences.size(0) * (sentences.size(1) - 1), + sentences.size(1) - 1)).view(-1).to(sentences.device) + if infonce_mask_fake_negative: + mask_fake_negative(similarity_matrix, labels) + similarity_matrix = similarity_matrix / temperature + # every neg+1 is positive start from 0 + loss = nn.CrossEntropyLoss()(similarity_matrix, labels) / world_size # avoid duplicate + else: + all_tensors = [] + for tensor in split_tensors: + all_tensors.append(tensor[1:]) + # cat all neg+1 tensors + sentences = torch.cat(all_tensors, dim=0) + length = 0 + for idx, tensor in enumerate(split_tensors): + # [D] * [B*(neg+1), D], neg numbers are different + similarity_matrix = torch.matmul(tensor[0], sentences.T) / temperature + labels = torch.tensor(length).to(tensor.device) + loss += nn.CrossEntropyLoss()(similarity_matrix, labels) + # next positive is neg+1 + length += tensor.size(0) - 1 + loss /= len(split_tensors) + loss /= world_size # avoid duplicate + return loss + + +@register_loss_func(LossType.online_contrastive) +def online_contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor: + sentence1, sentence2 = _parse_pair_sentence(outputs) + distance_metric = SiameseDistanceMetric.COSINE_DISTANCE + distance_matrix = distance_metric(sentence1, sentence2) + negs = distance_matrix[labels == 0] + poss = distance_matrix[labels == 1] + + # select hard positive and hard negative pairs + negative_pairs = negs[negs < (poss.max() if len(poss) > 1 else negs.mean())] + positive_pairs = poss[poss > (negs.min() if len(negs) > 1 else poss.mean())] + + positive_loss = positive_pairs.pow(2).sum() + margin = 0.5 + negative_loss = F.relu(margin - negative_pairs).pow(2).sum() + loss = positive_loss + negative_loss + return loss + + +def get_loss_func(loss_type: Optional[str]) -> Optional[Callable]: + if loss_type is None: + return None + return LOSS_MAPPING[loss_type]['loss_func'] diff --git a/swift/plugin/loss_scale/__init__.py b/swift/plugin/loss_scale/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..579be3b98ca209fb7f868a601bda14b64bbf561c --- /dev/null +++ b/swift/plugin/loss_scale/__init__.py @@ -0,0 +1 @@ +from .loss_scale import loss_scale_map diff --git a/swift/plugin/loss_scale/__pycache__/__init__.cpython-310.pyc b/swift/plugin/loss_scale/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f4718ba71d1890ea37fb7711a266870ad79fa3c Binary files /dev/null and b/swift/plugin/loss_scale/__pycache__/__init__.cpython-310.pyc differ diff --git a/swift/plugin/loss_scale/__pycache__/loss_scale.cpython-310.pyc b/swift/plugin/loss_scale/__pycache__/loss_scale.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1517249f17d4f4d1dbdca04d95c5e564922c7d76 Binary files 
/dev/null and b/swift/plugin/loss_scale/__pycache__/loss_scale.cpython-310.pyc differ
diff --git a/swift/plugin/loss_scale/__pycache__/utils.cpython-310.pyc b/swift/plugin/loss_scale/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c90e923546b76c5a2a30222932a2e795b660e7c7
Binary files /dev/null and b/swift/plugin/loss_scale/__pycache__/utils.cpython-310.pyc differ
diff --git a/swift/plugin/loss_scale/config/agentflan.json b/swift/plugin/loss_scale/config/agentflan.json
new file mode 100644
index 0000000000000000000000000000000000000000..2751fea02b15587835f21577221d155417d129ea
--- /dev/null
+++ b/swift/plugin/loss_scale/config/agentflan.json
@@ -0,0 +1,22 @@
+{
+    "response":{
+        "Name:": [1.0, 3.0],
+        "Action:": [1.0, 3.0],
+        "ACTION:": [1.0,3.0],
+        "Tool:": [1.0, 3.0],
+        "Command": [1.0, 3.0],
+        "Arguments:": [1.0, 3.0],
+        "action input": [1.0, 3.0],
+        "ACTION_INPUT:":[1.0, 3.0],
+        "Action Input:": [1.0, 3.0],
+        "Thought:": [1.0, 1.0],
+        "Final Answer:": [1.0, 1.0],
+        "Observation:": [2.0, 0.0]
+    },
+    "query":{
+        "What is the tool you want to use": [3.0],
+        "What are the required parameter names": [3.0],
+        "What is the value of": [3.0],
+        "What are the required parameter names for this tool": [3.0]
+    }
+}
diff --git a/swift/plugin/loss_scale/config/alpha_umi.json b/swift/plugin/loss_scale/config/alpha_umi.json
new file mode 100644
index 0000000000000000000000000000000000000000..fcdcbcb185066da0b768263562729d8361ebaa01
--- /dev/null
+++ b/swift/plugin/loss_scale/config/alpha_umi.json
@@ -0,0 +1,8 @@
+{
+    "Action:": [2.0, 2.0],
+    "Action Input:": [2.0, 2.0],
+    "Thought:": [1.0, 1.0],
+    "Final Answer:": [1.0, 1.0],
+    "Observation:": [2.0, 0.0],
+    "Next:": [2.0, 2.0]
+}
diff --git a/swift/plugin/loss_scale/config/hermes.json b/swift/plugin/loss_scale/config/hermes.json
new file mode 100644
index 0000000000000000000000000000000000000000..e8bfee3fc5d6cd8aa79c99f0f9b4fcd15b623645
--- /dev/null
+++ b/swift/plugin/loss_scale/config/hermes.json
@@ -0,0 +1,3 @@
+{
+    "<tool_call>(.+?)</tool_call>": [2.0]
+}
diff --git a/swift/plugin/loss_scale/config/ignore_empty_think.json b/swift/plugin/loss_scale/config/ignore_empty_think.json
new file mode 100644
index 0000000000000000000000000000000000000000..c7c2395fbb78294a543f09072620895e76ef1ea9
--- /dev/null
+++ b/swift/plugin/loss_scale/config/ignore_empty_think.json
@@ -0,0 +1,3 @@
+{
+    "<think>\n\n</think>\n\n": [0.0]
+}
diff --git a/swift/plugin/loss_scale/config/qwen.json b/swift/plugin/loss_scale/config/qwen.json
new file mode 100644
index 0000000000000000000000000000000000000000..731ba5340387e8a3467831877fdfb1cdd19fdc90
--- /dev/null
+++ b/swift/plugin/loss_scale/config/qwen.json
@@ -0,0 +1,6 @@
+{
+    "✿FUNCTION✿:": [2.0, 2.0],
+    "✿ARGS✿:": [2.0, 2.0],
+    "✿RETURN✿:": [1.0, 1.0],
+    "✿RESULT✿:": [2.0, 0.0]
+}
diff --git a/swift/plugin/loss_scale/config/react.json b/swift/plugin/loss_scale/config/react.json
new file mode 100644
index 0000000000000000000000000000000000000000..006f92948e1a6de28a1825fa2ef256dc1b09de81
--- /dev/null
+++ b/swift/plugin/loss_scale/config/react.json
@@ -0,0 +1,7 @@
+{
+    "Action:": [2.0, 2.0],
+    "Action Input:": [2.0, 2.0],
+    "Thought:": [1.0, 1.0],
+    "Final Answer:": [1.0, 1.0],
+    "Observation:": [2.0, 0.0]
+}
diff --git a/swift/plugin/loss_scale/loss_scale.py b/swift/plugin/loss_scale/loss_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..1540169e00f3e14dba1c019536d50fa3f9536c6f
--- /dev/null
+++ b/swift/plugin/loss_scale/loss_scale.py
@@ -0,0 +1,136 @@
+# Copyright (c) Alibaba, Inc.
and its affiliates. +import os +from typing import List, Optional, Tuple + +import json + +from swift.llm import Messages +from swift.llm.template.utils import ContextType +from .utils import calculate_loss_scale + + +class LossScale: + loss_scale_config = None # path + + def __init__(self): + if self.loss_scale_config is not None: + path = os.path.dirname(os.path.abspath(__file__)) + config_path = os.path.join(path, 'config', self.loss_scale_config) + with open(config_path, 'r', encoding='utf-8') as json_file: + self.loss_scale_map = json.load(json_file) + else: + self.loss_scale_map = None + + def get_loss_scale(self, + context: str, + context_type: ContextType, + is_last_round: bool, + *, + query: Optional[str] = None) -> Tuple[List[str], List[float]]: + """Calculate loss scale + + Args: + context: The input context + context_type: The type of this context, like response/suffix(eos token)/other(query/system, etc.) + is_last_round: If this is the last round of messages. + query: The query of this round. + + Returns: + A tuple, list of context and list of loss_scales + """ + if context_type in {ContextType.RESPONSE, ContextType.SUFFIX}: + loss_scale = 1. + else: + loss_scale = 0. + return [context], [loss_scale] + + def __call__(self, context_list: List[str], context_types: List[ContextType], messages: Messages, + **kwargs) -> Tuple[List[str], List[float]]: + res_context_list = [] + res_loss_scale = [] + i = 0 + n_round = len(messages) // 2 + for context, context_type in zip(context_list, context_types): + is_last_round = i + 1 == n_round + if context_type == ContextType.RESPONSE: + query = messages[2 * i]['content'] + assert context == messages[2 * i + 1]['content'] + kwargs = {'query': query} + i += 1 + new_context, loss_scale = self.get_loss_scale(context, context_type, is_last_round, **kwargs) + res_context_list += new_context + res_loss_scale += loss_scale + return res_context_list, res_loss_scale + + +class LastRoundLossScale(LossScale): + + def get_loss_scale(self, context: str, context_type: ContextType, is_last_round: bool, **kwargs): + if context_type == ContextType.RESPONSE: + return [context], [float(is_last_round)] + return super().get_loss_scale(context, context_type, is_last_round) + + +class AgentFlanLossScale(LossScale): + loss_scale_config = 'agentflan.json' + + def get_loss_scale(self, + context: str, + context_type: ContextType, + is_last_round: bool, + *, + query: Optional[str] = None): + if context_type == ContextType.RESPONSE: + return calculate_loss_scale(query, context, self.loss_scale_map['response'], self.loss_scale_map['query']) + return super().get_loss_scale(context, context_type, is_last_round) + + +class REACTLossScale(LossScale): + loss_scale_config = 'react.json' + + def get_loss_scale(self, + context: str, + context_type: ContextType, + is_last_round: bool, + *, + query: Optional[str] = None): + if context_type == ContextType.RESPONSE: + return calculate_loss_scale(query, context, self.loss_scale_map) + return super().get_loss_scale(context, context_type, is_last_round) + + +class QwenLossScale(REACTLossScale): + loss_scale_config = 'qwen.json' + + +class HermesLossScale(REACTLossScale): + loss_scale_config = 'hermes.json' + + +class AlphaUmiLossScale(REACTLossScale): + loss_scale_config = 'alpha_umi.json' + + +class TrainAllLossScale(LossScale): + + def get_loss_scale(self, context: str, context_type: ContextType, *args, **kwargs): + return [context], [1.] 
+ + +class IgnoreEmptyThink(REACTLossScale): + loss_scale_config = 'ignore_empty_think.json' + + +# Add your loss scale here, use --loss_scale xxx to train +loss_scale_map = { + 'last_round': LastRoundLossScale(), + 'default': LossScale(), + 'all': TrainAllLossScale(), + 'ignore_empty_think': IgnoreEmptyThink(), + # agent + 'react': REACTLossScale(), + 'hermes': HermesLossScale(), + 'qwen': QwenLossScale(), + 'agentflan': AgentFlanLossScale(), + 'alpha_umi': AlphaUmiLossScale(), +} diff --git a/swift/plugin/loss_scale/utils.py b/swift/plugin/loss_scale/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d60c592a5d025e689d2a232648fa54d19ca71ff0 --- /dev/null +++ b/swift/plugin/loss_scale/utils.py @@ -0,0 +1,58 @@ +from typing import Dict, List, Optional, Tuple + +from swift.llm.template import split_str_parts_by + + +def calculate_loss_scale(query: str, + response: str, + response_loss_scale_map: Dict[str, list], + query_loss_scale_map: Optional[Dict[str, list]] = None) -> Tuple[List[str], List[float]]: + """Calculate the loss scale by splitting the agent response. + + This algorithm comes from paper: https://arxiv.org/pdf/2309.00986.pdf + + Agent response format: + + ```text + Thought: you should always think about what to do + Action: the action to take, should be one of the above tools[fire_recognition, + fire_alert, call_police, call_fireman] + Action Input: the input to the action + Observation: the result of the action + ... (this Thought/Action/Action Input/Observation can be repeated zero or more times) + Thought: I now know the final answer + Final Answer: the final answer to the original input question + ``` + Returns: + A tuple of agent response parts and their weights. + """ + # query loss scale map + if query_loss_scale_map is not None: + for key in query_loss_scale_map.keys(): + if key in query: + if isinstance(query_loss_scale_map[key], (float, int)): + query_loss_scale_map[key] = [query_loss_scale_map[key]] + loss_scale_value = query_loss_scale_map[key][0] + return [response], [float(loss_scale_value)] + delimiters = [k for k, v in response_loss_scale_map.items() if len(v) == 2] + if delimiters: + agent_parts = split_str_parts_by(response, delimiters) + else: + regex_delimiters = [k for k, v in response_loss_scale_map.items() if len(v) == 1] + agent_parts = split_str_parts_by(response, regex_delimiters, regex_mode=True) + weights = [] + agent_content = [] + for c in agent_parts: + if c['key'] in response_loss_scale_map: + loss_scale = response_loss_scale_map[c['key']] + assert len(loss_scale) in {1, 2}, f'loss_scale: {loss_scale}' + if len(loss_scale) == 1: + weights += loss_scale + agent_content.append(c['content']) + else: + weights += loss_scale + agent_content += [c['key'], c['content']] + else: + weights.append(1.) + agent_content.append(c['content']) + return agent_content, weights diff --git a/swift/plugin/metric.py b/swift/plugin/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..410449815c27d6591290b4e4458888d758721a14 --- /dev/null +++ b/swift/plugin/metric.py @@ -0,0 +1,189 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
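+# Metrics are resolved by name through METRIC_MAPPING/get_metric at the bottom of
+# this file (--metric xxx). Each entry is a (compute_metrics, preprocess_logits)
+# pair; a hypothetical custom entry could be registered as:
+#
+#     METRIC_MAPPING['my_metric'] = (my_compute_metrics, None)
+#
+# where my_compute_metrics accepts a transformers EvalPrediction and returns a
+# Dict[str, float].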
+import time +from abc import ABC, abstractmethod +from typing import Dict, List, Literal + +import numpy as np +import torch +from transformers.trainer_utils import EvalPrediction + +from swift.utils import Serializer, get_logger + +logger = get_logger() + + +class Metric(ABC): + + def __init__(self): + self._default = {} + self._default_factory = {} + + def add_state(self, name: str, default=None, default_factory=None) -> None: + if not hasattr(self, '_default'): + raise AttributeError('Please call super().__init__() first.') + if default is None: + self._default_factory[name] = default_factory + assert name not in self._default, f'self._default: {self._default}' + default = default_factory() + else: + self._default[name] = default + assert name not in self._default_factory, f'self._default_factory: {self._default_factory}' + setattr(self, name, default) + + def reset(self): + for k, v in self._default.items(): + setattr(self, k, v) + for k, v in self._default_factory.items(): + setattr(self, k, v()) + + @abstractmethod + def update(self, *args, **kwargs): + pass + + @abstractmethod + def compute(self): + pass + + +class InferStats(Metric): + + def __init__(self): + super().__init__() + self.add_state('start_runtime', default_factory=lambda: time.perf_counter()) + self.add_state('num_prompt_tokens', default_factory=dict) + self.add_state('num_generated_tokens', default_factory=dict) + + def update(self, output): + id_ = output.id + self.num_prompt_tokens[id_] = output.usage.prompt_tokens + self.num_generated_tokens[id_] = output.usage.completion_tokens + + def compute(self): + runtime = time.perf_counter() - self.start_runtime + num_samples = len(self.num_generated_tokens) + num_generated_tokens = sum(self.num_generated_tokens.values()) + return { + 'num_prompt_tokens': sum(self.num_prompt_tokens.values()), + 'num_generated_tokens': num_generated_tokens, + 'num_samples': num_samples, + 'runtime': runtime, + 'samples/s': num_samples / runtime, + 'tokens/s': num_generated_tokens / runtime, + } + + +class MeanMetric(Metric): + + def __init__(self, nan_value=0): + super().__init__() + self.nan_value = nan_value + self.add_state('state', default=0.) 
+ self.add_state('count', default=0) + + def update(self, state: torch.Tensor): + if isinstance(state, (torch.Tensor, np.ndarray)): + state = state.tolist() + + if isinstance(state, (list, tuple)): + count = len(state) + state = sum(state) + else: + count = 1 + + self.state += state + self.count += count + + def compute(self): + return { + 'value': self.state / self.count if self.count > 0 else self.nan_value, + } + + +def compute_rouge_bleu(preds: List[str], labels: List[str]): + import jieba + from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu + from rouge.rouge import Rouge + score_dict = {key: MeanMetric() for key in ['rouge-1', 'rouge-2', 'rouge-l', 'bleu-4']} + + for pred, label in zip(preds, labels): + hypothesis = list(jieba.cut(pred)) + reference = list(jieba.cut(label)) + if not hypothesis or not reference: + continue + rouge = Rouge() + scores = rouge.get_scores(' '.join(hypothesis), ' '.join(reference))[0] + for k, v in scores.items(): + score_dict[k].update(v['f']) + bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) + score_dict['bleu-4'].update(bleu_score) + + return {k: round(v.compute()['value'] * 100, 6) for k, v in score_dict.items()} + + +def compute_nlg_metrics(prediction) -> Dict[str, float]: + preds, labels = prediction[0], prediction[1] + new_preds, new_labels = [], [] + for i in range(preds.shape[0]): + new_preds.append(Serializer.from_tensor(preds[i])) + new_labels.append(Serializer.from_tensor(labels[i])) + return compute_rouge_bleu(new_preds, new_labels) + + +def compute_acc(preds, + labels, + *, + acc_strategy: Literal['token', 'seq'] = 'token', + is_encoder_decoder: bool = False) -> Dict[str, List[float]]: + + if isinstance(preds, torch.Tensor): + if torch.is_floating_point(labels): + return {} + preds = preds.cpu().numpy() + labels = labels.cpu().numpy() + if preds.ndim >= 2 and not is_encoder_decoder: + labels = labels[..., 1:] + preds = preds[..., :-1] + if np.issubdtype(labels.dtype, np.floating) or preds.shape != labels.shape: + return {} + + masks = labels != -100 + if acc_strategy == 'token' or preds.ndim == 1: + acc_list = (preds[masks] == labels[masks]).tolist() + else: + acc_list = [] + for i, m in enumerate(masks): + acc_list.append(np.all(preds[i, m] == labels[i, m])) + return {f'{acc_strategy}_acc' if preds.ndim >= 2 else 'acc': acc_list} + + +def compute_acc_metrics(eval_prediction: EvalPrediction, + *, + acc_strategy: Literal['token', 'seq'] = 'token', + is_encoder_decoder: bool = False) -> Dict[str, float]: + + metric = compute_acc( + eval_prediction.predictions, + eval_prediction.label_ids, + acc_strategy=acc_strategy, + is_encoder_decoder=is_encoder_decoder) + if len(metric) == 0: + return {} + return {k: sum(v) / len(v) for k, v in metric.items()} + + +def preprocess_logits_for_acc(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + if isinstance(logits, (list, tuple)): + logits = logits[0] + preds = logits.argmax(dim=-1) + return preds + + +# Add your own metric calculation method here, use --metric xxx to train +METRIC_MAPPING = { + 'acc': (compute_acc_metrics, preprocess_logits_for_acc), + 'nlg': (compute_nlg_metrics, None), +} + + +def get_metric(metric: str): + return METRIC_MAPPING[metric] diff --git a/swift/plugin/multi_turn.py b/swift/plugin/multi_turn.py new file mode 100644 index 0000000000000000000000000000000000000000..2e9881892eaf26e4ee7c2b2ebd7702264f748f03 --- /dev/null +++ b/swift/plugin/multi_turn.py @@ -0,0 +1,42 @@ +def 
check_math_result_and_give_tips(inputs): + from .orm import MathAccuracy + acc = MathAccuracy() + # a trick + prompt = 'But wait... It seems I made a mistake,' + contents = [input['messages'][-1]['content'] for input in inputs] + rewards = acc(contents, [input['solution'] for input in inputs]) + for reward, input in zip(rewards, inputs): + content = input['messages'][-1]['content'] + if reward < 1 and prompt not in content: + if '' in content: + content = content[:content.index('')] + if '' in content: + content = content[:content.index('')] + content += prompt + input['messages'][-1]['content'] = content + input['finished'] = False + else: + input['finished'] = True + return inputs + + +def check_math_result_and_give_tips_multi_turn(inputs): + from .orm import MathAccuracy + acc = MathAccuracy() + prompt = 'The answer is not correct, It seems You made a mistake, you need to recheck very carefully.' + contents = [input['messages'][-1]['content'] for input in inputs] + rewards = acc(contents, [input['solution'] for input in inputs]) + for reward, input in zip(rewards, inputs): + content = input['messages'][-2]['content'] + if reward < 1 and prompt not in content: + input['messages'].append({'role': 'user', 'content': prompt}) + input['finished'] = False + else: + input['finished'] = True + return inputs + + +multi_turns = { + 'math_tip_trick': check_math_result_and_give_tips, + 'math_tip_trick_multi_turn': check_math_result_and_give_tips_multi_turn, +} diff --git a/swift/plugin/optimizer.py b/swift/plugin/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..05a4b6ef8da78a9cf0662d04b887ca5f84aafb54 --- /dev/null +++ b/swift/plugin/optimizer.py @@ -0,0 +1,100 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import math +import os +import sys + +from transformers import Trainer + +from swift.trainers.optimizers.galore import create_optimizer_and_scheduler +from swift.utils import get_dist_setting + + +def calculate_max_steps(args: 'TrainArguments', dataset) -> int: + if args.max_steps and args.max_steps > 0: + max_steps = args.max_steps + else: + len_dataset = len(dataset) + _, _, world_size, _ = get_dist_setting() + total_train_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * world_size + num_update_steps_per_epoch = len_dataset // total_train_batch_size + num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) + max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) + return max_steps + + +def create_galore_optimizer(args, model, dataset): + training_steps = calculate_max_steps(args, dataset) + optimizer, lr_scheduler = create_optimizer_and_scheduler( + model, args, args.galore_config, training_steps, lr=args.learning_rate, weight_decay=args.weight_decay) + # trainer cannot serialize galore_config + args.galore_config = None + return optimizer, lr_scheduler + + +def create_lorap_optimizer(args, model, dataset): + optimizer_grouped_parameters = None + if hasattr(model, 'create_optimizer_param_groups'): + # Lora+ parameter groups + optimizer_grouped_parameters = model.create_optimizer_param_groups( + lr=args.learning_rate, weight_decay=args.weight_decay) + + if optimizer_grouped_parameters is None: + # Default parameter groups + decay_parameters = Trainer.get_decay_parameter_names(None, model) + optimizer_grouped_parameters = [ + { + 'params': [p for n, p in model.named_parameters() if (n in decay_parameters and p.requires_grad)], + 'weight_decay': args.weight_decay, + }, + { + 'params': [p for n, p 
in model.named_parameters() if (n not in decay_parameters and p.requires_grad)], + 'weight_decay': 0.0, + }, + ] + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args) + return optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs), None + + +def create_muon_optimizer(args, model, dataset): + from swift.llm import git_clone_github, get_model_arch + if not args.local_repo_path: + args.local_repo_path = git_clone_github('https://github.com/MoonshotAI/Moonlight.git') + sys.path.append(os.path.join(args.local_repo_path, 'examples')) + from toy_train import Muon + + # parse args.optim_args + optim_args = {} + if args.optim_args: + for mapping in args.optim_args.replace(' ', '').split(','): + key, value = mapping.split('=') + optim_args[key] = value + + model_arch = get_model_arch(model.model_meta.model_arch) + embed_key = model_arch.embedding or 'embed_tokens' + lm_head_key = model_arch.lm_head or 'lm_head' + muon_params = [ + p for n, p in model.named_parameters() + if p.requires_grad and p.ndim >= 2 and embed_key not in n and lm_head_key not in n + ] + adamw_params = [ + p for n, p in model.named_parameters() + if p.requires_grad and not (p.ndim >= 2 and embed_key not in n and lm_head_key not in n) + ] + + return Muon( + lr=args.learning_rate, + wd=args.weight_decay, + muon_params=muon_params, + adamw_params=adamw_params, + adamw_betas=(args.adam_beta1, args.adam_beta2), + adamw_eps=args.adam_epsilon, + **optim_args, + ), None + + +# Add your own optimizers here, use --optimizer xxx to train +optimizers_map = { + 'galore': create_galore_optimizer, + 'lorap': create_lorap_optimizer, + 'muon': create_muon_optimizer, +} diff --git a/swift/plugin/orm.py b/swift/plugin/orm.py new file mode 100644 index 0000000000000000000000000000000000000000..d5f1980f9067eab862bae2e01d09129d0d4fa750 --- /dev/null +++ b/swift/plugin/orm.py @@ -0,0 +1,406 @@ +import os +import re +from typing import Dict, List, Union + +import json + +from swift.llm import InferRequest + + +class ORM: + + def __call__(self, **kwargs) -> List[float]: + raise NotImplementedError + + +class ReactORM(ORM): + + @staticmethod + def evaluate_action_reward(action_pred: list, action_ref: list, cand_list: list, ref_list: list): + f1 = [] + for i in range(len(action_pred)): + ref_action = action_ref[i] + pred_action = action_pred[i] + + ref_input = ref_list[i] + cand_input = cand_list[i] + + ref_is_json = False + try: + ref_input_json = json.loads(ref_input) + ref_is_json = True + except Exception: + ref_input_json = ref_input + + cand_is_json = False + try: + cand_input_json = json.loads(cand_input) + cand_is_json = True + except Exception: + cand_input_json = cand_input + + if ref_action != pred_action or (ref_is_json ^ cand_is_json): + f1.append(0) + elif not ref_is_json and not cand_is_json: + rougel = ReactORM.evaluate_rougel([ref_input_json], [cand_input_json]) + if rougel is None or rougel < 10: + f1.append(0) + elif 10 <= rougel < 20: + f1.append(0.1) + else: + f1.append(1) + else: + if not isinstance(ref_input_json, dict) or not isinstance(cand_input_json, dict): + # This cannot be happen, but: + # line 62, in evaluate_action_reward + # for k, v in ref_input_json.items(): + # AttributeError: 'str' object has no attribute 'items' + # print(f'>>>>>>ref_input_json: {ref_input_json}, cand_input_json: {cand_input_json}') + f1.append(0) + continue + + half_match = 0 + full_match = 0 + if ref_input_json == {}: + if cand_input_json == {}: + f1.append(1) + else: + f1.append(0) + else: + for k, v in 
ref_input_json.items(): + if k in cand_input_json.keys(): + if cand_input_json[k] == v: + full_match += 1 + else: + half_match += 1 + + recall = (0.5 * half_match + full_match) / (len(ref_input_json) + 1e-30) + precision = (0.5 * half_match + full_match) / (len(cand_input_json) + 1e-30) + try: + f1.append((2 * recall * precision) / (recall + precision)) + except Exception: + f1.append(0.0) + + if f1[0] == 1.0: + return True + else: + return False + + @staticmethod + def parse_action(text): + if 'Action Input:' in text: + input_idx = text.rindex('Action Input:') + action_input = text[input_idx + len('Action Input:'):].strip() + else: + action_input = '{}' + + if 'Action:' in text: + action_idx = text.rindex('Action:') + action = text[action_idx + len('Action:'):].strip() + if 'Action Input:' in action: + input_idx = action.index('Action Input:') + action = action[:input_idx].strip() + else: + action = 'none' + return action, action_input + + @staticmethod + def parse_output(text): + action, action_input = ReactORM.parse_action(text) + return action, action_input + + def __call__(self, infer_requests: List[Union[InferRequest, Dict]], solution: List[str], **kwargs) -> List[float]: + rewards = [] + if not isinstance(infer_requests[0], str): + predictions = [request['messages'][-1]['content'] for request in infer_requests] + else: + predictions = infer_requests + for prediction, ground_truth in zip(predictions, solution): + if prediction.endswith('Observation:'): + prediction = prediction[:prediction.index('Observation:')].strip() + action_ref = [] + action_input_ref = [] + action_pred = [] + action_input_pred = [] + reference = ground_truth + prediction = prediction.replace('<|endoftext|>', '').replace('<|im_end|>', '').strip() + ref_action, ref_input = ReactORM.parse_output(reference) + pred_action, pred_input = ReactORM.parse_output(prediction) + action_ref.append(ref_action) + action_input_ref.append(ref_input) + if pred_action is None: + action_pred.append('none') + else: + action_pred.append(pred_action) + + if pred_input is None: + action_input_pred.append('{}') + else: + action_input_pred.append(pred_input) + + reward = ReactORM.evaluate_action_reward(action_pred, action_ref, action_input_pred, action_input_ref) + rewards.append(float(reward)) + return rewards + + @staticmethod + def evaluate_rougel(cand_list: list, ref_list: list): + if len(ref_list) == 0: + return None + try: + from rouge import Rouge + rouge = Rouge() + rouge_score = rouge.get_scores(hyps=cand_list, refs=ref_list, avg=True) + rougel = rouge_score['rouge-l']['f'] + return rougel + except Exception: + return None + + +class MathORM(ORM): + + def __init__(self): + from transformers.utils import strtobool + self.use_opencompass = strtobool(os.environ.get('USE_OPENCOMPASS_EVALUATOR', 'False')) + if self.use_opencompass: + from opencompass.datasets.math import MATHEvaluator + self.evaluator = MATHEvaluator() + + @staticmethod + def check_terminate(answers: Union[str, List[str]]) -> List[bool]: + if isinstance(answers, str): + answers = [answers] + results = [] + for answer in answers: + results.append('\\boxed' in answer) + return results + + @staticmethod + def extract_boxed_result(text): + pattern = r'\\boxed{([^}]*)}' + match = re.search(pattern, text) + if match: + return match.group(1).strip() + else: + return text + + @staticmethod + def clean_latex(latex_str): + latex_str = re.sub(r'\\\(|\\\)|\\\[|\\]', '', latex_str) + latex_str = latex_str.replace('}}', '}').replace('{', '').replace('}', '') + return 
latex_str.strip() + + @staticmethod + def parse_expression(latex_str): + from sympy import simplify + from sympy.parsing.latex import parse_latex + try: + expr = parse_latex(latex_str) + return simplify(expr) + except Exception: + return None + + @staticmethod + def compare_consecutive(first, second): + cleaned_list = [MathORM.clean_latex(latex) for latex in [first, second]] + parsed_exprs = [MathORM.parse_expression(latex) for latex in cleaned_list] + if hasattr(parsed_exprs[0], 'equals') and hasattr(parsed_exprs[1], 'equals'): + value = parsed_exprs[0].equals(parsed_exprs[1]) + else: + value = parsed_exprs[0] == parsed_exprs[1] + if value is None: + value = False + return value + + def __call__(self, infer_requests: List[Union[InferRequest, Dict]], ground_truths: List[str], + **kwargs) -> List[float]: + rewards = [] + predictions = [request.messages[-1]['content'] for request in infer_requests] + for prediction, ground_truth in zip(predictions, ground_truths): + if '# Answer' in prediction: + prediction = prediction.split('# Answer')[1] + if '# Answer' in ground_truth: + ground_truth = ground_truth.split('# Answer')[1] + prediction = prediction.strip() + ground_truth = ground_truth.strip() + prediction = MathORM.extract_boxed_result(prediction) + ground_truth = MathORM.extract_boxed_result(ground_truth) + if self.use_opencompass: + reward = self.evaluator.is_equiv(prediction, ground_truth) + else: + reward = MathORM.compare_consecutive(prediction, ground_truth) + rewards.append(float(reward)) + return rewards + + +class MathAccuracy(ORM): + + def __init__(self): + import importlib.util + assert importlib.util.find_spec('math_verify') is not None, ( + "The math_verify package is required but not installed. Please install it using 'pip install math_verify'.") + + def __call__(self, completions, solution, **kwargs) -> List[float]: + from latex2sympy2_extended import NormalizationConfig + from math_verify import LatexExtractionConfig, parse, verify + rewards = [] + for content, sol in zip(completions, solution): + gold_parsed = parse(sol, extraction_mode='first_match') + if len(gold_parsed) != 0: + # We require the answer to be provided in correct latex (no malformed operators) + answer_parsed = parse( + content, + extraction_config=[ + LatexExtractionConfig( + normalization_config=NormalizationConfig( + nits=False, + malformed_operators=False, + basic_latex=True, + equations=True, + boxed=True, + units=True, + ), + # Ensures that boxed is tried first + boxed_match_priority=0, + try_extract_without_anchor=False, + ) + ], + extraction_mode='first_match', + ) + # edge case + try: + reward = float(verify(gold_parsed, answer_parsed)) + except Exception: + reward = 0.0 + else: + # If the gold solution is not parseable, we reward 0 to skip this example + reward = 0.0 + rewards.append(reward) + return rewards + + +class Format(ORM): + + def __call__(self, completions, **kwargs) -> List[float]: + """Reward function that checks if the completion has a specific format.""" + pattern = r'^.*?\s*.*?(?![\s\S])' + matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completions] + return [1.0 if match else 0.0 for match in matches] + + +class ReActFormat(ORM): + + def __call__(self, completions, **kwargs) -> List[float]: + """Reward function that checks if the completion has a specific format.""" + pattern = r'^.*?\s*Action:.*?Action Input:.*?$' + matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completions] + return [1.0 if match else 0.0 for match 
in matches] + + +class CosineReward(ORM): + # https://arxiv.org/abs/2502.03373 + def __init__(self, + tokenizer=None, + cosine_min_len_value_wrong: float = -0.5, + cosine_max_len_value_wrong: float = 0.0, + cosine_min_len_value_correct: float = 1.0, + cosine_max_len_value_correct: float = 0.5, + cosine_max_len: int = 1000, + accuracy_orm=None): + self.tokenizer = tokenizer + self.min_len_value_wrong = cosine_min_len_value_wrong + self.max_len_value_wrong = cosine_max_len_value_wrong + self.min_len_value_correct = cosine_min_len_value_correct + self.max_len_value_correct = cosine_max_len_value_correct + self.max_len = cosine_max_len + self.accuracy_orm = accuracy_orm or MathAccuracy() + + @staticmethod + def cosfn(t, T, min_value, max_value): + import math + return max_value - (max_value - min_value) * (1 - math.cos(t * math.pi / T)) / 2 + + def __call__(self, completions, solution, **kwargs) -> List[float]: + acc_rewards = self.accuracy_orm(completions, solution, **kwargs) + rewards = [] + for content, acc_reward in zip(completions, acc_rewards): + is_correct = acc_reward >= 1. + if is_correct: + # Swap min/max for correct answers + min_value = self.max_len_value_correct + max_value = self.min_len_value_correct + else: + min_value = self.max_len_value_wrong + max_value = self.min_len_value_wrong + gen_len = len(self.tokenizer.encode(content)) + reward = self.cosfn(gen_len, self.max_len, min_value, max_value) + rewards.append(reward) + return rewards + + +class RepetitionPenalty(ORM): + # https://arxiv.org/abs/2502.03373 + def __init__(self, repetition_n_grams: int = 3, repetition_max_penalty: float = -1.0): + self.ngram_size = repetition_n_grams + self.max_penalty = repetition_max_penalty + + @staticmethod + def zipngram(text: str, ngram_size: int): + words = text.lower().split() + return zip(*[words[i:] for i in range(ngram_size)]) + + def __call__(self, completions, **kwargs) -> List[float]: + """ + reward function the penalizes repetitions + + Args: + completions: List of model completions + """ + rewards = [] + for completion in completions: + if completion == '': + rewards.append(0.0) + continue + if len(completion.split()) < self.ngram_size: + rewards.append(0.0) + continue + + ngrams = set() + total = 0 + for ng in self.zipngram(completion, self.ngram_size): + ngrams.add(ng) + total += 1 + + scaling = 1 - len(ngrams) / total + reward = scaling * self.max_penalty + rewards.append(reward) + return rewards + + +class SoftOverlong(ORM): + + def __init__(self, tokenizer, soft_max_length, soft_cache_length): + self.tokenizer = tokenizer + assert soft_cache_length < soft_max_length + self.soft_max_length = soft_max_length + self.soft_cache_length = soft_cache_length + + def __call__(self, completions, **kwargs) -> List[float]: + rewards = [] + for completion in completions: + completion_length = len(self.tokenizer.encode(completion)) + expected_len = self.soft_max_length - self.soft_cache_length + exceed_len = completion_length - expected_len + rewards.append(min(-exceed_len / self.soft_cache_length, 0)) + return rewards + + +orms = { + 'toolbench': ReactORM, + 'math': MathORM, + 'accuracy': MathAccuracy, + 'format': Format, + 'react_format': ReActFormat, + 'cosine': CosineReward, + 'repetition': RepetitionPenalty, + 'soft_overlong': SoftOverlong, +} diff --git a/swift/plugin/prm.py b/swift/plugin/prm.py new file mode 100644 index 0000000000000000000000000000000000000000..2f2b833128f4faefc18b4b4cddf204501fcd4a9a --- /dev/null +++ b/swift/plugin/prm.py @@ -0,0 +1,154 @@ +import os +from 
typing import Any, Dict, List, Union + +import json + +from swift.llm import InferRequest + + +class PRM: + + def __call__(self, **kwargs) -> List[Any]: + raise NotImplementedError + + +SYSTEM = """ +You are a process reward model, give the reward value of the answer, you must follow the instructions below: + +1. Output a float reward value between -1.0 and 1.0, -1.0 means the worst answer, 1.0 means the best answer, please think step by step to give your reasons and thoughts, but the reward must appare at the end with this format: **Reward: your-reward-value**. + +2. The answer may be incomplete, you must give the reward by the existing part of the answer, taking into account semantic coherence, logical correctness, and clarity. + +3. A ground truth answer will be given to you, it may be not the best one, consider it as a reference example. + +Begin! +""" # noqa + +QUERY = """ +The original question or the previous conversation: + +#query# + +Here is the ground truth as the reference: + +#ground_truth# + +Given the upper information, give your reward(-1.0~1.0) of the following answer: + +#response# +""" + + +class QwenMaxPRM(PRM): + + def __call__(self, infer_requests: List[Union[InferRequest, Dict]], ground_truths: List[str], + **kwargs) -> List[float]: + # TODO: check request_config + rewards = [] + + from openai import OpenAI + + client = OpenAI( + api_key=os.getenv('DASHSCOPE_API_KEY'), + base_url='https://dashscope.aliyuncs.com/compatible-mode/v1', + ) + + for request, ground_truth in zip(infer_requests, ground_truths): + previous = request['messages'][:-1] + if previous[0]['role'] == 'system': + previous = previous[1:] + + assert request['messages'][-1]['role'] == 'assistant' + query = QUERY.replace('#query#', json.dumps(previous)) + query = query.replace('#ground_truth#', ground_truth) + query = query.replace('#response#', request['messages'][-1]['content']) + messages = [ + { + 'role': 'system', + 'content': SYSTEM + }, + { + 'role': 'user', + 'content': query + }, + ] + completion = client.chat.completions.create( + model='qwen-max', + messages=messages, + ) + + content = completion.choices[0].message.content + if 'Reward:' not in content: + rewards.append(0.) + else: + try: + reward = float(content.split('Reward:')[1].strip().replace('*', '')) + rewards.append(reward) + except Exception: + rewards.append(0.) 
+ + return rewards + + +class ClientPRM(PRM): + + def __init__(self, api_key=None, base_url=None, model=None): + from swift.llm import InferClient + import os + if api_key is None: + api_key = os.getenv('DASHSCOPE_API_KEY') + if base_url is None: + base_url = 'https://dashscope.aliyuncs.com/compatible-mode/v1' + if model is None: + model = 'qwen-plus' + self.infer_engine = InferClient(base_url=base_url, api_key=api_key) + self.infer_engine.strict = False + self.infer_kwargs = { + 'model': model, + } + + def __call__(self, infer_requests: List[Union[InferRequest, Dict]], ground_truths: List[str], + **kwargs) -> List[float]: + prm_infer_requests = [] + request_config = kwargs.get('request_config') + for request, ground_truth in zip(infer_requests, ground_truths): + previous = request['messages'][:-1] + if previous[0]['role'] == 'system': + previous = previous[1:] + + assert request['messages'][-1]['role'] == 'assistant' + query = QUERY.replace('#query#', json.dumps(previous)) + query = query.replace('#ground_truth#', ground_truth) + query = query.replace('#response#', request['messages'][-1]['content']) + messages = [ + { + 'role': 'system', + 'content': SYSTEM + }, + { + 'role': 'user', + 'content': query + }, + ] + + prm_infer_requests.append(InferRequest(messages=messages)) + + responses = self.infer_engine.infer(prm_infer_requests, request_config=request_config, **self.infer_kwargs) + rewards = [] + for response in responses: + content = response.choices[0].message.content + if 'Reward:' not in content: + rewards.append(0.) + else: + try: + reward = float(content.split('Reward:')[1].strip().replace('*', '')) + rewards.append(reward) + except Exception: + rewards.append(0.) + return rewards + + +prms = { + 'qwen_max': QwenMaxPRM, + 'client': ClientPRM, +} diff --git a/swift/plugin/rm_plugin.py b/swift/plugin/rm_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..136223542992a01e574bd80418fec1e5bc8a505a --- /dev/null +++ b/swift/plugin/rm_plugin.py @@ -0,0 +1,229 @@ +import re +import textwrap +from copy import deepcopy +from typing import Dict, List + +import torch + +from swift.llm import PtEngine, RequestConfig, Template, to_device +from swift.llm.infer.protocol import ChatCompletionResponse +from swift.utils import get_logger + +logger = get_logger() + + +class DefaultRMPlugin: + """ + Default Reward Model Plugin + + This class implements the default processing logic for reward models. + It assumes that `self.model` is a classification model with a value head(output dimmension 1). + The first logits value from the model's output is used as the reward score. + """ + + def __init__(self, model, template): + self.model = model + self.template: Template = template + + def __call__(self, inputs): + batched_inputs = [self.template.encode(deepcopy(infer_request)) for infer_request in inputs] + reward_inputs = to_device(self.template.data_collator(batched_inputs), self.model.device) + reward_inputs.pop('labels') + + with torch.inference_mode(): + return self.model(**reward_inputs).logits[:, 0] + + +class GenRMPlugin(DefaultRMPlugin): + + def __init__(self, model, template): + """ + Generative Reward Model Plugin Example. + + This method sets up the reward model plugin by initializing the PtEngine for efficient inference, + configuring the request parameters, and defining the system prompt that guides the reward model in + evaluating responses. + + Args: + model (torch.nn.Module): The generative reward model. 
+ template (Template): The template used for encoding input data. + """ + + super().__init__(model, template) + # initilize PTEngine to infer + self.engine = PtEngine.from_model_template(self.model, self.template, max_batch_size=0) # 0: no limit + self.request_config = RequestConfig() # customise your request config here + self.system = textwrap.dedent(""" + Based on the dialogue history, analyze in detail whether the model's response is accurate, complete, and relevant. + Assign a reward score between 0 and 1, where 0 indicates completely incorrect and 1 indicates fully correct. + Before finishing your response, please assign a reward using the following format: + + Reward: {reward} + + For example: + Reward: 0.85 + """) # noqa + + def __call__(self, inputs): + """ + Compute reward scores for the provided inputs. + + This method processes each input by converting dialogue messages into a query, sending the query to the + reward model for inference, and extracting the reward scores from the model's responses. The final reward + for each input is the average of all extracted scores. + Args: + inputs (List[Dict]): A list of input requests. Each input request is a dictionary containing: + - 'messages' (List[Dict]): messages from the training model. Each message dictionary includes: + - 'role' (str): The role of the speaker (e.g., 'user', 'assistant'). + - 'content' (str): The content of the message. + - Additional dataset columns as key-value pairs (e.g., 'solutions', 'images'). + Returns: + torch.Tensor: A tensor containing the average reward scores for each input. The tensor has a shape of (N,), + where N is the number of input requests. + """ + + rm_inputs = self.prepare_rm_inputs(inputs) + results = self.engine.infer(rm_inputs, self.request_config, use_tqdm=False) + rewards = self.compute_rewards(results) + return torch.tensor(rewards, dtype=torch.float32) + + def prepare_rm_inputs(self, inputs: List[Dict]) -> List[Dict]: + """ + Prepare inputs for the reward model by converting messages into queries. + + Args: + inputs (List[Dict]): A list of input requests. + + Returns: + List[Dict]: Processed inputs for the reward model. + """ + rm_inputs = [] + for idx, infer_request in enumerate(inputs): + # Deep copy to prevent modification of original input + rm_infer_request = deepcopy(infer_request) + + # Extract and convert messages to a single query string + messages = rm_infer_request.get('messages') + query = self.messages_to_query(messages) + + # Construct new messages tailored for the reward model + rm_messages = [{'role': 'system', 'content': self.system}, {'role': 'user', 'content': query}] + + # Update the messages in the reward infer request + rm_infer_request['messages'] = rm_messages + rm_inputs.append(rm_infer_request) + return rm_inputs + + @staticmethod + def extract_reward(model_output: str) -> float: + """ + Extract the reward score from the model's output. + + Args: + model_output (str): The model's output string, expected to follow the format "Reward: {reward}". + + Returns: + float: The extracted reward score. + + Raises: + ValueError: If the reward score cannot be extracted or the format is incorrect. + """ + match = re.search(r'Reward:\s*([0-1](?:\.\d+)?)', model_output) + if match: + return float(match.group(1)) + else: + logger.warning("Unable to extract reward score from the model's output, set reward to 0") + return None + + @staticmethod + def messages_to_query(messages): + """ + Compress a list of message dictionaries into a single query string. 
+ + Args: + messages (list[dict]): A list of message dictionaries, each containing: + - 'role' (str): The role of the speaker (e.g., 'user', 'assistant'). + - 'content' (str): The content of the message. + + Returns: + str: A single string that concatenates all messages in a formatted manner. + + Example: + >>> messages = [ + ... {'role': 'user', 'content': 'Hello, how are you?'}, + ... {'role': 'assistant', 'content': 'I am fine, thank you! How can I assist you today?'}, + ... {'role': 'user', 'content': 'Can you help me with my homework?'} + ... ] + >>> print(messages_to_query(messages)) + User: Hello, how are you? + Assistant: I am fine, thank you! How can I assist you today? + User: Can you help me with my homework? + """ + # Initialize an empty list to hold formatted messages + formatted_messages = [] + + # Define a mapping for role capitalization if needed + role_mapping = { + 'user': 'User', + 'assistant': 'Assistant', + 'system': 'System' + # Add more roles here as needed + } + + for idx, message in enumerate(messages): + if not isinstance(message, dict): + raise TypeError(f'Each message must be a dictionary. Found {type(message)} at index {idx}.') + + # Extract 'role' and 'content' from each message + role = message.get('role') + content = message.get('content') + if not content: + continue + + # Capitalize the role using the mapping, default to capitalized original role + role_formatted = role_mapping.get(role.lower(), role.capitalize()) + + # Append the formatted message to the list + formatted_messages.append(f'{role_formatted}: {content}') + + # Join all formatted messages with newline characters + query = '\n'.join(formatted_messages) + + return query + + def compute_rewards(self, results: List[ChatCompletionResponse]) -> List[float]: + """ + Compute average reward scores from the reward model's outputs. + + Args: + results (List[ChatCompletionResponse]): A list of results from the reward model. + + Returns: + List[float]: A list of average reward scores. + """ + rewards = [] + for idx, output in enumerate(results): + try: + cur_rewards = [] + for choice in output.choices: + response = choice.message.content + reward = self.extract_reward(response) + cur_rewards.append(reward) + cur_rewards = [r for r in cur_rewards if r is not None] + if cur_rewards: + average_reward = sum(cur_rewards) / len(cur_rewards) + else: + average_reward = 0.0 + logger.warning('No valid rewards extracted. Assigning reward score of 0.0.') + + rewards.append(average_reward) + except Exception as e: + logger.error(f'Error computing reward: {e}') + rewards.append(0.0) # Assign default reward score on failure + return rewards + + +rm_plugins = { + 'default': DefaultRMPlugin, + 'genrm': GenRMPlugin, +} diff --git a/swift/plugin/tuner.py b/swift/plugin/tuner.py new file mode 100644 index 0000000000000000000000000000000000000000..a8cb44d5251d92749f8aeef189df4a3f572b506e --- /dev/null +++ b/swift/plugin/tuner.py @@ -0,0 +1,92 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
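+# Custom tuner plugins: subclass `Tuner` (or `PeftTuner` for PEFT-based tuners), implement
+# `prepare_model` / `save_pretrained` / `from_pretrained`, and register the class in
+# `extra_tuners` at the bottom of this file so it can be selected with `--train_type xxx`.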
+from typing import Optional + +import torch +from peft import IA3Config, PeftModel, get_peft_model + +from swift.llm import MODEL_ARCH_MAPPING, ModelKeys +from swift.utils import find_all_linears + + +class Tuner: + + @staticmethod + def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module: + """Prepare a new model with a tuner + + Args: + args: The training arguments + model: The model instance + + Returns: + The wrapped model + """ + raise NotImplementedError + + @staticmethod + def save_pretrained( + model: torch.nn.Module, + save_directory: str, + state_dict: Optional[dict] = None, + safe_serialization: bool = True, + **kwargs, + ) -> None: + """Save when save_steps reaches + + Args: + model: The wrapped model by `prepare_model` + save_directory: The directory to save + safe_serialization: Use safetensors or not + """ + raise NotImplementedError + + @staticmethod + def from_pretrained(model: torch.nn.Module, model_id: str, **kwargs) -> torch.nn.Module: + """Load the ckpt_dir + + Args: + model: The original model instance. + model_id: The model id or ckpt_dir to load + Returns: + The wrapped model instance + """ + raise NotImplementedError + + +class PeftTuner(Tuner): + + @staticmethod + def save_pretrained( + model: torch.nn.Module, + save_directory: str, + state_dict: Optional[dict] = None, + safe_serialization: bool = True, + **kwargs, + ) -> None: + model.save_pretrained(save_directory, safe_serialization=safe_serialization, **kwargs) + + @staticmethod + def from_pretrained(model: torch.nn.Module, model_id: str, **kwargs) -> torch.nn.Module: + return PeftModel.from_pretrained(model, model_id, **kwargs) + + +# Here gives a simple example of IA3 +class IA3(PeftTuner): + + @staticmethod + def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module: + model_arch: ModelKeys = MODEL_ARCH_MAPPING[model.model_meta.model_arch] + ia3_config = IA3Config( + target_modules=find_all_linears(model), feedforward_modules='.*' + model_arch.mlp.split('{}.')[1] + '.*') + return get_peft_model(model, ia3_config) + + +class DummyTuner(PeftTuner): + + @staticmethod + def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module: + return model + + +# Add your own tuner here, use --train_type xxx to begin +extra_tuners = {'ia3': IA3, 'dummy': DummyTuner} diff --git a/swift/trainers/__init__.py b/swift/trainers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..16ae3dfe72c7ad9b0041e25932103e3495f60019 --- /dev/null +++ b/swift/trainers/__init__.py @@ -0,0 +1,49 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import (EvaluationStrategy, FSDPOption, HPSearchBackend, HubStrategy, IntervalStrategy, + SchedulerType) + +from swift.utils.import_utils import _LazyModule +from . 
import callback + +try: + # https://github.com/huggingface/transformers/pull/25702 + from transformers.trainer_utils import ShardedDDPOption +except ImportError: + ShardedDDPOption = None + +if TYPE_CHECKING: + from .arguments import Seq2SeqTrainingArguments, TrainingArguments + from .rlhf_trainer import (CPOTrainer, DPOTrainer, KTOTrainer, ORPOTrainer, RLHFTrainerMixin, PPOTrainer, + RewardTrainer, GRPOTrainer) + from .rlhf_arguments import DPOConfig, CPOConfig, KTOConfig, ORPOConfig, PPOConfig, RewardConfig + from .trainer_factory import TrainerFactory + from .trainers import Seq2SeqTrainer, Trainer, EmbeddingTrainer + from .mixin import SwiftMixin + +else: + _extra_objects = {k: v for k, v in globals().items() if not k.startswith('_')} + _import_structure = { + 'arguments': ['Seq2SeqTrainingArguments', 'TrainingArguments'], + 'rlhf_arguments': + ['DPOConfig', 'CPOConfig', 'KTOConfig', 'ORPOConfig', 'PPOConfig', 'RewardConfig', 'GRPOConfig'], + 'rlhf_trainer': [ + 'CPOTrainer', 'DPOTrainer', 'KTOTrainer', 'ORPOTrainer', 'RLHFTrainerMixin', 'PPOTrainer', 'RewardTrainer', + 'GRPOTrainer' + ], + 'trainer_factory': ['TrainerFactory'], + 'trainers': ['Seq2SeqTrainer', 'Trainer', 'EmbeddingTrainer'], + 'mixin': ['SwiftMixin'], + } + + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects=_extra_objects, + ) diff --git a/swift/trainers/__pycache__/__init__.cpython-310.pyc b/swift/trainers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5637bb132bf20f7b22f12cda9ddd136cdbbe2b3 Binary files /dev/null and b/swift/trainers/__pycache__/__init__.cpython-310.pyc differ diff --git a/swift/trainers/__pycache__/arguments.cpython-310.pyc b/swift/trainers/__pycache__/arguments.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b5515df1b3774102c5ae4cb671c11044fd814ff Binary files /dev/null and b/swift/trainers/__pycache__/arguments.cpython-310.pyc differ diff --git a/swift/trainers/__pycache__/callback.cpython-310.pyc b/swift/trainers/__pycache__/callback.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31b08c2bcb73274c7babecbe114d997e3d946fba Binary files /dev/null and b/swift/trainers/__pycache__/callback.cpython-310.pyc differ diff --git a/swift/trainers/__pycache__/mixin.cpython-310.pyc b/swift/trainers/__pycache__/mixin.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26626f516705e4d1cf6c2206e779e0a9110036a0 Binary files /dev/null and b/swift/trainers/__pycache__/mixin.cpython-310.pyc differ diff --git a/swift/trainers/__pycache__/rlhf_arguments.cpython-310.pyc b/swift/trainers/__pycache__/rlhf_arguments.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..045d8106d600422279c7e20cc9cd63c36b795350 Binary files /dev/null and b/swift/trainers/__pycache__/rlhf_arguments.cpython-310.pyc differ diff --git a/swift/trainers/__pycache__/trainer_factory.cpython-310.pyc b/swift/trainers/__pycache__/trainer_factory.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fcf32f0f56201f2728ce19adae7888eca084afc Binary files /dev/null and b/swift/trainers/__pycache__/trainer_factory.cpython-310.pyc differ diff --git a/swift/trainers/__pycache__/trainers.cpython-310.pyc b/swift/trainers/__pycache__/trainers.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..9e26366b0ed7c69c52f1809f7833a3218cbfdada Binary files /dev/null and b/swift/trainers/__pycache__/trainers.cpython-310.pyc differ diff --git a/swift/trainers/__pycache__/utils.cpython-310.pyc b/swift/trainers/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4909d56844c1622ca1345b95d9a33f7766943433 Binary files /dev/null and b/swift/trainers/__pycache__/utils.cpython-310.pyc differ diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..14c98b5c1a7a14b6cd361565e3382688aeeddcb1 --- /dev/null +++ b/swift/trainers/arguments.py @@ -0,0 +1,214 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import math +import os +import platform +from dataclasses import dataclass, field +from functools import wraps +from typing import List, Literal, Optional, Union + +import torch +import torch.utils.checkpoint +from transformers.training_args import TrainingArguments as HfTrainingArguments +from transformers.training_args_seq2seq import Seq2SeqTrainingArguments as HfSeq2SeqTrainingArguments + +from swift.utils import get_dist_setting, get_logger, is_liger_available, use_torchacc +from .optimizers.galore import GaLoreConfig + +logger = get_logger() + + +@dataclass +class TrainArgumentsMixin: + """ + check_model (bool): Flag to check the model is latest. Default is True. + acc_strategy (Literal['token', 'seq']): Strategy for accumulation. Default is 'token'. + """ + per_device_train_batch_size: int = 1 + per_device_eval_batch_size: int = 1 + gradient_accumulation_steps: Optional[int] = None + + gradient_checkpointing: bool = True + gradient_checkpointing_kwargs: Optional[Union[dict, str]] = None + logging_first_step: bool = True + logging_steps: int = 5 + + weight_decay: float = 0.1 + adam_beta2: float = 0.95 + lr_scheduler_type: str = 'cosine' + lr_scheduler_kwargs: Optional[Union[dict, str]] = None + report_to: List[str] = field(default_factory=lambda: ['tensorboard']) + dataloader_num_workers: Optional[int] = None + dataloader_prefetch_factor: Optional[int] = None + use_liger_kernel: bool = False + + # extra + check_model: bool = True + acc_strategy: Literal['token', 'seq'] = 'token' + train_dataloader_shuffle: bool = True + max_epochs: Optional[int] = None + + # torchacc + metric_warmup_step: Optional[float] = 0 + fsdp_num: int = 1 + acc_steps: int = 1 + + # train-eval loop args + eval_use_evalscope: bool = False + eval_datasets: List[str] = field(default_factory=list) + eval_limit: Optional[int] = None + eval_datasets_args: Optional[Union[str, dict]] = None + eval_generation_config: Optional[Union[str, dict]] = None + + def _fix_gradient_checkpointing(self): + # fix use_reentrant + if hasattr(torch.utils.checkpoint, '_old_checkpoint'): # avoid double patching + return + # Consistent with the default behavior of transformers. + use_reentrant_ = ( + self.gradient_checkpointing_kwargs.get('use_reentrant', True) + if self.gradient_checkpointing_kwargs else True) + _old_checkpoint = torch.utils.checkpoint.checkpoint + + @wraps(_old_checkpoint) + def _new_checkpoint(*args, use_reentrant=None, **kwargs): + return _old_checkpoint(*args, use_reentrant=use_reentrant_, **kwargs) + + torch.utils.checkpoint._old_checkpoint = _old_checkpoint + torch.utils.checkpoint.checkpoint = _new_checkpoint + try: + # Fix the old version of transformers. 
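+            # Some transformers versions also reference `checkpoint` through
+            # `transformers.modeling_utils`, so the wrapped function is assigned there as well;
+            # the except clause below keeps this a best-effort patch on versions where that
+            # module or attribute is unavailable.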
+ import transformers.modeling_utils + transformers.modeling_utils.checkpoint = _new_checkpoint + except (ImportError, AttributeError): + pass + + def _init_liger(self): + if self.use_liger_kernel: + assert is_liger_available(), 'use_liger_kernel requires liger_kernels, try `pip install liger-kernel`' + + def __post_init__(self): + from swift.llm.argument.base_args.model_args import ModelArguments + if use_torchacc(): + self.dataloader_drop_last = True + if self.gradient_accumulation_steps is None: + world_size = get_dist_setting()[2] + self.gradient_accumulation_steps = max(1, math.ceil(16 / self.per_device_train_batch_size / world_size)) + logger.info(f'Setting args.gradient_accumulation_steps: {self.gradient_accumulation_steps}') + if self.lr_scheduler_kwargs: + self.lr_scheduler_kwargs = ModelArguments.parse_to_dict(self.lr_scheduler_kwargs) + if self.gradient_checkpointing_kwargs: + self.gradient_checkpointing_kwargs = ModelArguments.parse_to_dict(self.gradient_checkpointing_kwargs) + self._fix_gradient_checkpointing() + self._init_liger() + if self.dataloader_num_workers is None: + if platform.system() == 'Windows': + self.dataloader_num_workers = 0 + else: + self.dataloader_num_workers = 1 + logger.info(f'Setting args.dataloader_num_workers: {self.dataloader_num_workers}') + if self.dataloader_prefetch_factor is None and self.dataloader_num_workers > 0: + self.dataloader_prefetch_factor = 10 + if self.eval_use_evalscope: + try: + import evalscope + except ImportError: + raise ImportError('evalscope is not installed, please install it by `pip install evalscope`') + self.eval_datasets_args = ModelArguments.parse_to_dict(self.eval_datasets_args) + self.eval_generation_config = ModelArguments.parse_to_dict(self.eval_generation_config) + + super().__post_init__() + + +@dataclass +class SwiftArgumentsMixin(TrainArgumentsMixin): + # Value copied from TrainArguments + train_type: Optional[str] = None + optimizer: Optional[str] = None + local_repo_path: Optional[str] = None + galore_config: Optional[GaLoreConfig] = None + + def __post_init__(self): + if hasattr(self, 'output_dir'): + self.output_dir = os.path.abspath(os.path.expanduser(self.output_dir)) + super().__post_init__() + + @property + def place_model_on_device(self): + return False if use_torchacc() else super().place_model_on_device + + +@dataclass +class GRPOArgumentsMixin: + epsilon: float = 0.2 + epsilon_high: Optional[float] = None + top_k: int = 50 + top_p: float = 0.9 + repetition_penalty: float = 1. + num_infer_workers: int = 1 + # vllm + vllm_device: List[str] = field(default_factory=lambda: ['auto']) + vllm_gpu_memory_utilization: float = 0.9 + vllm_max_model_len: Optional[int] = None + vllm_max_num_seqs: int = 256 + vllm_enforce_eager: bool = False + vllm_limit_mm_per_prompt: Optional[Union[dict, str]] = None # '{"image": 5, "video": 2}' + vllm_enable_prefix_caching: bool = True + # reward function args, see details in swift/plugin/orm.py + # cosine reward, https://arxiv.org/abs/2502.03373 + cosine_min_len_value_wrong: float = -0.5 # r^w_0 in paper, Reward for wrong answers with zero completion length. + cosine_max_len_value_wrong: float = 0.0 # r^w_L in paper, Reward for wrong answers with max completion length. + cosine_min_len_value_correct: float = 1.0 # r^c_0 in paper, Reward for correct answers with zero completion length. + cosine_max_len_value_correct: float = 0.5 # r^c_L in paper, Reward for correct answers with max completion length. 
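+    # Length shaping (see CosineReward.cosfn in swift/plugin/orm.py):
+    #   reward = max_value - (max_value - min_value) * (1 - cos(len * pi / cosine_max_len)) / 2,
+    # so shorter correct answers earn more reward and longer wrong answers are penalized less.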
+ cosine_max_len: Optional[int] = None # Lmax in paper, default equal to max_completion_length + # repetition penalty, https://arxiv.org/abs/2502.03373 + repetition_n_grams: int = 3 + repetition_max_penalty: float = -1.0 + + reward_model: Optional[List[str]] = None + reward_model_plugin: Optional[List[str]] = None + # LMDeploy in GRPO + use_lmdeploy: bool = False + lmdeploy_device: Optional[str] = 'auto' + lmdeploy_session_len: Optional[int] = None + lmdeploy_cache_max_entry_count: float = 0.8 + + async_generate: bool = False + tensor_parallel_size: int = 1 + sleep_level: int = 0 + move_model_batches: Optional[int] = None + offload_optimizer: bool = False + offload_model: bool = False + gc_collect_after_offload: bool = False + multi_turn_func: Optional[str] = None + + # DAPO, https://arxiv.org/abs/2503.14476 + dynamic_sample: bool = False + max_resample_times: int = 3 + overlong_filter: bool = False + soft_max_length: Optional[int] = None + soft_cache_length: Optional[int] = None + + # Dr. GRPO, https://arxiv.org/abs/2503.20783 + scale_rewards: bool = True + + # compatible with trl main branch(0.17.0.dev0) + wandb_log_unique_prompts: Optional[bool] = None + + # external vllm + vllm_server_host: Optional[str] = None + vllm_server_port: int = 8000 + vllm_server_timeout: float = 240.0 + vllm_client = None + + # dataset + dataset_shuffle: Optional[bool] = True + + +@dataclass +class TrainingArguments(SwiftArgumentsMixin, HfTrainingArguments): + pass + + +@dataclass +class Seq2SeqTrainingArguments(SwiftArgumentsMixin, HfSeq2SeqTrainingArguments): + pass diff --git a/swift/trainers/callback.py b/swift/trainers/callback.py new file mode 100644 index 0000000000000000000000000000000000000000..7d0343d88fb9e59ef7e91d4e50e3494e4652cb23 --- /dev/null +++ b/swift/trainers/callback.py @@ -0,0 +1,124 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import math +import os +import time + +from tqdm import tqdm +from transformers import trainer +from transformers.trainer_callback import (DefaultFlowCallback, PrinterCallback, ProgressCallback, TrainerControl, + TrainerState) +from transformers.trainer_utils import IntervalStrategy, has_length, speed_metrics + +from swift.utils import append_to_jsonl, is_pai_training_job, use_torchacc +from ..utils.utils import format_time +from .arguments import TrainingArguments + + +def add_train_message(logs, state, start_time) -> None: + logs['global_step/max_steps'] = f'{state.global_step}/{state.max_steps}' + train_percentage = state.global_step / state.max_steps if state.max_steps else 0. 
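+    # Progress and ETA are derived from wall-clock time: remaining = elapsed * (1 - p) / p.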
+ logs['percentage'] = f'{train_percentage * 100:.2f}%' + elapsed = time.time() - start_time + logs['elapsed_time'] = format_time(elapsed) + if train_percentage != 0: + logs['remaining_time'] = format_time(elapsed / train_percentage - elapsed) + for k, v in logs.items(): + if isinstance(v, float): + logs[k] = round(logs[k], 8) + + +class ProgressCallbackNew(ProgressCallback): + + def on_train_begin(self, args, state, control, **kwargs): + if state.is_world_process_zero: + self.training_bar = tqdm(desc='Train', total=state.max_steps, dynamic_ncols=True) + self.current_step = 0 + self.start_time = time.time() + if use_torchacc(): + self.warmup_start_time = 0 + self.warmup_metric = None + self.metric_warmup_step = int(args.metric_warmup_step + * state.max_steps) if args.metric_warmup_step < 1 else args.metric_warmup_step + + def on_prediction_step(self, args, state: TrainerState, control, eval_dataloader=None, **kwargs): + if state.is_world_process_zero and has_length(eval_dataloader): + if self.prediction_bar is None: + if self.training_bar is not None: + self.training_bar.fp.write('\n') + self.prediction_bar = tqdm( + desc='Val', total=len(eval_dataloader), leave=True, dynamic_ncols=True, position=0) + self.prediction_bar.update() + + def on_log(self, args: TrainingArguments, state: TrainerState, control, logs=None, **kwargs): + + if use_torchacc(): + if state.global_step >= self.metric_warmup_step and self.warmup_start_time == 0: + self.warmup_start_time = time.time() + self.metric_warmup_step = state.global_step + if state.max_steps == state.global_step and self.warmup_metric is None: + num_steps = state.max_steps - self.metric_warmup_step + num_total_samples = args.train_dataset_sample + num_after_warmup_samples = int(num_total_samples / state.max_steps * num_steps) + self.warmup_metric = speed_metrics('warmup_train', self.warmup_start_time, num_after_warmup_samples, + num_steps) + self.warmup_metric['num_total_samples'] = num_total_samples + self.warmup_metric['num_after_warmup_samples'] = num_after_warmup_samples + if 'train_samples_per_second' in logs: + logs.update(self.warmup_metric) + state.log_history[-1] = logs + + add_train_message(logs, state, self.start_time) + if not is_pai_training_job() and state.is_world_process_zero: + jsonl_path = os.path.join(args.output_dir, 'logging.jsonl') + append_to_jsonl(jsonl_path, logs) + super().on_log(args, state, control, logs, **kwargs) + if state.is_world_process_zero and self.training_bar is not None: + self.training_bar.refresh() + + +class DefaultFlowCallbackNew(DefaultFlowCallback): + + def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + control = super().on_step_end(args, state, control, **kwargs) + # save the last ckpt + evaluation_strategy = args.eval_strategy if hasattr(args, 'eval_strategy') else args.evaluation_strategy + if state.global_step == state.max_steps: + if evaluation_strategy != IntervalStrategy.NO: + control.should_evaluate = True + if args.save_strategy != IntervalStrategy.NO: + control.should_save = True + return control + + def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + control = super().on_epoch_end(args, state, control, **kwargs) + evaluation_strategy = args.eval_strategy if hasattr(args, 'eval_strategy') else args.evaluation_strategy + if args.max_epochs is not None and args.max_epochs <= math.ceil(state.epoch): + if evaluation_strategy != IntervalStrategy.NO: + control.should_evaluate = True + if 
args.save_strategy != IntervalStrategy.NO: + control.should_save = True + control.should_training_stop = True + return control + + +class PrinterCallbackNew(PrinterCallback): + + def on_train_begin(self, args, state, control, **kwargs): + self.start_time = time.time() + return super().on_train_begin(args, state, control, **kwargs) + + def on_log(self, args, state, control, logs=None, **kwargs): + add_train_message(logs, state, self.start_time) + if not is_pai_training_job() and state.is_world_process_zero: + jsonl_path = os.path.join(args.output_dir, 'logging.jsonl') + append_to_jsonl(jsonl_path, logs) + + _ = logs.pop('total_flos', None) + if state.is_world_process_zero: + print(logs, flush=True) + + +# monkey patching +trainer.DEFAULT_PROGRESS_CALLBACK = ProgressCallbackNew +trainer.DEFAULT_CALLBACKS = [DefaultFlowCallbackNew] +trainer.PrinterCallback = PrinterCallbackNew diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..fbd382d99f394e16eb362ecb58da969eccef066c --- /dev/null +++ b/swift/trainers/mixin.py @@ -0,0 +1,516 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Part of the implementation is borrowed from huggingface/transformers. +import inspect +import os +import shutil +import time +from contextlib import contextmanager +from copy import copy +from functools import partial +from types import MethodType +from typing import Callable, Dict, List, Optional, Tuple, Union + +import safetensors +import torch +import torch.distributed as dist +import torch.nn as nn +import transformers +from datasets import Dataset as HfDataset +from modelscope import check_local_model_is_latest +from packaging import version +from peft import PeftModel +from torch.nn import Module +from torch.utils.data import DataLoader +from transformers import PreTrainedModel +from transformers.data.data_collator import DataCollator +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.modeling_utils import unwrap_model +from transformers.trainer import TrainerCallback +from transformers.trainer_utils import EvalPrediction, IntervalStrategy +from transformers.utils import is_torch_npu_available + +from swift.hub import get_hub +from swift.llm import BatchSamplerShard, DataLoaderDispatcher, DataLoaderShard, Template +from swift.plugin import MeanMetric, compute_acc, extra_tuners +from swift.tuners import SwiftModel +from swift.utils import get_logger, is_mp_ddp, use_torchacc +from swift.utils.torchacc_utils import ta_trim_graph +from ..utils.torch_utils import get_device_count +from .arguments import TrainingArguments +from .utils import can_return_loss, find_labels, get_function, is_instance_of_ms_model + +try: + from trl import AutoModelForCausalLMWithValueHead +except (ImportError, RuntimeError): + AutoModelForCausalLMWithValueHead = None + +logger = get_logger() + + +class SwiftMixin: + + def __init__(self, + model: Union[PreTrainedModel, Module] = None, + args: TrainingArguments = None, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[HfDataset] = None, + eval_dataset: Optional[Union[HfDataset, Dict[str, HfDataset]]] = None, + template: Optional[Template] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + compute_loss_func: Optional[Callable] = None, + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = 
(None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + **kwargs) -> None: + if not hasattr(train_dataset, '__len__') and args.dataloader_num_workers > 1: + args.dataloader_num_workers = 1 + logger.warning('Using IterableDataset, setting args.dataloader_num_workers to 1.') + + if args.check_model and hasattr(model, 'model_dir'): + from swift.utils.logger import ms_logger_ignore_error + with ms_logger_ignore_error(): + check_local_model_is_latest( + model.model_dir, user_agent={ + 'invoked_by': 'local_trainer', + 'third_party': 'swift', + }) + if eval_dataset is None and args: + args.evaluation_strategy = IntervalStrategy.NO + args.eval_strategy = IntervalStrategy.NO + + self._custom_metrics = {} + self.template = template + self.max_memory = 0 + self.hub = get_hub() + + self.model_meta = model.model_meta + with self.hub.patch_hub(): + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=template.tokenizer, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + **kwargs) + + self.compute_loss_func = compute_loss_func + if get_function(model.__class__.forward) is not get_function(model.forward): + self.label_names = find_labels(model) + self.can_return_loss = can_return_loss(model) + self.label_names = self.label_names or ['labels'] + self.start_time = time.time() + if self.template.sequence_parallel_size > 1: + from swift.trainers.sequence_parallel import sequence_parallel + sequence_parallel.prepare_trainer(self) + + def _save_initial_model(self, output_dir): + # pissa/olora/lora-ga + model = unwrap_model(self.model) + if isinstance(model, PeftModel): + config = model.peft_config.get('default') + init_lora_weights = getattr(config, 'init_lora_weights', None) + if (isinstance(init_lora_weights, str) + and any(s in init_lora_weights for s in ('pissa', 'olora', 'lora-ga'))): + config.init_lora_weights = True + model.save_pretrained(os.path.join(output_dir, 'initial_model')) + config.init_lora_weights = init_lora_weights + + def _save_converted_model(self, output_dir): + # pissa/olora/lora-ga + model = unwrap_model(self.model) + if isinstance(model, PeftModel): + config = model.peft_config.get('default') + init_lora_weights = getattr(config, 'init_lora_weights', None) + if isinstance(init_lora_weights, str): + config = copy(config) + os.makedirs(os.path.join(output_dir, 'converted'), exist_ok=True) + if 'lora-ga' in init_lora_weights: + try: + from lora_ga.entrypoint import LoraGAContext + with LoraGAContext(model): + model.save_pretrained( + os.path.join(output_dir, 'converted', 'default'), + path_initial_model_for_weight_conversion=os.path.join( + os.path.dirname(output_dir), 'initial_model'), + ) + model.peft_config['default'] = config + except ImportError as e: + error_message = """ + Since 'LoRA-GA' is not implemented by PEFT, you will need to install it directly from GitHub. + Command: 'pip install git+https://github.com/lxline/LoRA-GA.git'. 
+ """ + logger.info(error_message) + raise RuntimeError(error_message) from e + elif 'pissa' in init_lora_weights or 'olora' in init_lora_weights: + model.save_pretrained( + os.path.join(output_dir, 'converted', 'default'), + path_initial_model_for_weight_conversion=os.path.join( + os.path.dirname(output_dir), 'initial_model'), + ) + model.peft_config['default'] = config + + def _load_optimizer_and_scheduler(self, *args, **kwargs): + super()._load_optimizer_and_scheduler(*args, **kwargs) + if is_mp_ddp(): + # fix mp+ddp adamw + for v in self.optimizer.state.values(): + if 'step' in v: + # not on the same device + device_set = set([t.device for t in v.values()]) - {v['step'].device, torch.device('cpu')} + if len(device_set) >= 1: + v['step'] = v['step'].to('cpu') + + def _save_model(self, output_dir: Optional[str] = None, state_dict=None): + # model + supported_classes = (SwiftModel, PreTrainedModel, PeftModel) + supported_names = ('SentenceTransformer') + if AutoModelForCausalLMWithValueHead is not None: + supported_classes = supported_classes + (AutoModelForCausalLMWithValueHead, ) + save_safetensors = self.args.save_safetensors + if not isinstance(self.model, supported_classes) and self.model.__class__.__name__ not in supported_names: + if state_dict is None: + state_dict = self.model.state_dict() + + _unwrap_model = unwrap_model(self.model) + if isinstance(_unwrap_model, supported_classes): + _unwrap_model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=save_safetensors) + else: + logger.info('Trainer.model is not a `PreTrainedModel`, only saving its state dict.') + if save_safetensors: + safetensors.torch.save_file(state_dict, os.path.join(output_dir, 'model.safetensors')) + else: + torch.save(state_dict, os.path.join(output_dir, 'pytorch_model.bin')) + elif AutoModelForCausalLMWithValueHead and isinstance(self.model, AutoModelForCausalLMWithValueHead): + # save reward model + state_dict = self.model.state_dict() + decoder_state_dict, v_head_state_dict = {}, {} + for name, param in state_dict.items(): + if name.startswith('v_head.'): + v_head_state_dict[name] = param + else: + decoder_state_dict[name.replace('pretrained_model.', '', 1)] = param + self.model.pretrained_model.save_pretrained( + output_dir, state_dict=decoder_state_dict or None, safe_serialization=save_safetensors) + if save_safetensors: + from safetensors.torch import save_file + save_file( + v_head_state_dict, os.path.join(output_dir, 'value_head.safetensors'), metadata={'format': 'pt'}) + else: + torch.save(v_head_state_dict, os.path.join(output_dir, 'value_head.bin')) + elif is_instance_of_ms_model(self.model): + PreTrainedModel.save_pretrained( + self.model, output_dir, state_dict=state_dict, safe_serialization=save_safetensors) + elif self.args.train_type in extra_tuners: + extra_tuners[self.args.train_type].save_pretrained( + self.model, output_dir, state_dict=state_dict, safe_serialization=save_safetensors) + else: + if self.model.__class__.__name__ != 'SentenceTransformer': + self.model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=save_safetensors) + else: + + @contextmanager + def save_context(): + save_pretrained = self.model[0].auto_model.save_pretrained + _state_dict = { + key[len('0.auto_model.'):] if 'auto_model' in key else key: value + for key, value in state_dict.items() + } + self.model[0].auto_model.save_pretrained = partial( + self.model[0].auto_model.save_pretrained, state_dict=_state_dict) + yield + self.model[0].auto_model.save_pretrained = 
save_pretrained + + with save_context(): + self.model.save_pretrained(output_dir, safe_serialization=save_safetensors) + # copy sentencetransformers files + from swift.utils import copy_files_by_pattern + copy_files_by_pattern(self.model.model_dir, output_dir, '*.py') + copy_files_by_pattern(self.model.model_dir, output_dir, '*.json') + + def _save(self, output_dir: Optional[str] = None, state_dict=None): + """Compatible with swift and peft""" + # If we are executing this function, we are the process zero, so we don't check for that. + output_dir = output_dir if output_dir is not None else self.args.output_dir + os.makedirs(output_dir, exist_ok=True) + self._save_model(output_dir, state_dict) + # training_args.bin + torch.save(self.args, os.path.join(output_dir, 'training_args.bin')) + self._save_converted_model(output_dir) + # args.json + args_path = os.path.join(os.path.dirname(output_dir), 'args.json') + if os.path.exists(args_path): + shutil.copy(args_path, os.path.join(output_dir, 'args.json')) + # predict.jsonl + predict_jsonl = os.path.join(os.path.dirname(output_dir), 'predict.jsonl') + if os.path.exists(predict_jsonl): + shutil.move(predict_jsonl, os.path.join(output_dir, 'predict.jsonl')) + + is_adapter = isinstance(self.model, (SwiftModel, PeftModel)) + # tokenizer + if not is_adapter: + from swift.llm import save_checkpoint + additional_saved_files = self.model_meta.additional_saved_files + save_checkpoint( + None, + self.template.processor, + output_dir, + model_dirs=[self.model.model_dir], + additional_saved_files=additional_saved_files) + if getattr(self.model, 'origin_generation_config', None): + self.model.origin_generation_config.save_pretrained(output_dir) + + def _fix_zero3_gather_all_parameters(self) -> None: + if is_deepspeed_zero3_enabled() and not hasattr(self.deepspeed, '_zero3_consolidated_16bit_state_dict_origin'): + parameters = inspect.signature(self.deepspeed._zero3_consolidated_16bit_state_dict).parameters + if 'exclude_frozen_parameters' in parameters: + + def _zero3_consolidated_16bit_state_dict(model, exclude_frozen_parameters=False): + unwrapped = unwrap_model(model) + exclude_frozen_parameters = False + if isinstance(unwrapped, SwiftModel) and unwrapped.has_additional_modules: + exclude_frozen_parameters = True + if isinstance(unwrapped, PeftModel): + exclude_frozen_parameters = True + return model._zero3_consolidated_16bit_state_dict_origin(exclude_frozen_parameters) + + self.deepspeed._zero3_consolidated_16bit_state_dict_origin = ( + self.deepspeed._zero3_consolidated_16bit_state_dict) + self.deepspeed._zero3_consolidated_16bit_state_dict = MethodType(_zero3_consolidated_16bit_state_dict, + self.deepspeed) + + def _save_checkpoint(self, *args, **kwargs): + self.state.last_model_checkpoint = os.path.join(self.args.output_dir, f'checkpoint-{self.state.global_step}') + self._fix_zero3_gather_all_parameters() + result = super()._save_checkpoint(*args, **kwargs) + logger.info(f'Saving model checkpoint to {self.state.last_model_checkpoint}') + return result + + @staticmethod + @contextmanager + def _fix_grad_norm_nan(): + from accelerate import Accelerator + origin_clip_grad_norm_ = Accelerator.clip_grad_norm_ + + def clip_grad_norm_(self, parameters, *args, **kwargs): + # If NaN occurs, ignore weight updates. 
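+            # `parameters` may be a generator, so materialize it before clipping; if the resulting
+            # global norm is NaN, clearing `.grad` on every parameter makes the following
+            # optimizer step a no-op for this batch instead of writing NaNs into the weights.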
+ parameters = list(parameters) + grad_norm = origin_clip_grad_norm_(self, parameters, *args, **kwargs) + if isinstance(grad_norm, torch.Tensor) and grad_norm.isnan().item(): + for p in parameters: + p.grad = None + return grad_norm + + Accelerator.clip_grad_norm_ = clip_grad_norm_ + try: + yield + finally: + Accelerator.clip_grad_norm_ = origin_clip_grad_norm_ + + def train(self, *args, **kwargs): + if self.model_meta.is_multimodal: + models = [] + for model_name in ['model', 'ref_model', 'value_model']: + model = getattr(self, model_name, None) + if isinstance(model, nn.Module): + models.append(model) + + reward_model = getattr(self, 'reward_model', None) + if reward_model is not None: + if isinstance(reward_model, list): + models.extend([m for m in reward_model if isinstance(m, nn.Module)]) + elif isinstance(reward_model, nn.Module): + models.append(reward_model) + + models = list(set(models)) # Deduplicate + self.template.register_post_encode_hook(models) + logger.info(f'Successfully registered post_encode hook: {[model.__class__.__name__ for model in models]}.') + self._save_initial_model(self.args.output_dir) + with self.hub.patch_hub(), self._fix_grad_norm_nan(): + res = super().train(*args, **kwargs) + self.template.remove_post_encode_hook() + return res + + def push_to_hub(self, *args, **kwargs): + with self.hub.patch_hub(): + return super().push_to_hub(*args, **kwargs) + + def get_max_cuda_memory(self, device: Optional[Union[torch.device, int]] = None) -> float: + if device is None: + mems = [torch.cuda.max_memory_reserved(device=device) for device in range(get_device_count())] + else: + mems = [torch.cuda.max_memory_reserved(device=device)] + mem = sum(mems) / 1024**3 + self.max_memory = max(self.max_memory, mem) + return mem + + def _maybe_log_save_evaluate(self, tr_loss, *args, **kwargs): + if self.control.should_log and self.state.global_step > self._globalstep_last_logged: + self.control.should_log = False + + # all_gather + mean() to get average loss over all processes + tr_loss_scalar = self._nested_gather(tr_loss).mean().item() + loss = tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged) + logs: Dict[str, float] = {'loss': loss} # loss first + + for k, metric in self._custom_metrics.items(): + value = metric.compute() + if len(value) == 1: + val = list(value.values())[0] + logs[k] = val + else: + for k_suffix, val in value.items(): + new_k = f'{k}_{k_suffix}' + logs[new_k] = val + metric.reset() + + if version.parse(transformers.__version__) >= version.parse('4.38'): + grad_norm = args[0] + if grad_norm is not None: + logs['grad_norm'] = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm + logs['learning_rate'] = self._get_learning_rate() + if not is_torch_npu_available(): + logs['memory(GiB)'] = round(self.get_max_cuda_memory(), 2) + + elapse_time = time.time() - self.start_time + logs['train_speed(iter/s)'] = round(self.state.global_step / elapse_time, 6) + for k in list(logs.keys()): + if logs[k] is None: + logs.pop(k) + tr_loss -= tr_loss + self._total_loss_scalar += tr_loss_scalar + self._globalstep_last_logged = self.state.global_step + self.store_flos() + self.log(logs) + + if self.args.eval_use_evalscope and self.control.should_evaluate: + self._evalscope_eval() + super()._maybe_log_save_evaluate(tr_loss, *args, **kwargs) + + def create_optimizer_and_scheduler(self, num_training_steps: int): + if self.args.optimizer is not None: + from swift.plugin import optimizers_map + optimizer_callback = 
optimizers_map[self.args.optimizer] + self.optimizer, self.lr_scheduler = optimizer_callback(self.args, self.model, self.train_dataset) + if self.optimizer is None: + self.create_optimizer() + if self.lr_scheduler is None: + self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer) + else: + super().create_optimizer_and_scheduler(num_training_steps=num_training_steps) + + def _compute_acc(self, outputs, labels) -> None: + args = self.args + acc_steps = args.acc_steps + preds = outputs.logits.argmax(dim=-1) + if self.state.global_step % acc_steps == 0: + if use_torchacc(): + ta_trim_graph() + preds = preds.to('cpu') + labels = labels.to('cpu') + metrics = compute_acc( + preds, labels, acc_strategy=args.acc_strategy, is_encoder_decoder=self.template.is_encoder_decoder) + for k, v in metrics.items(): + if k not in self._custom_metrics: + self._custom_metrics[k] = MeanMetric(nan_value=None) + self._custom_metrics[k].update(v) + + @torch.no_grad() + def _evalscope_eval(self): + from ..llm.eval.utils import EvalModel + from evalscope import TaskConfig, run_task + from evalscope.constants import EvalType + + self.model.eval() + max_batch_size = self.args.per_device_eval_batch_size + custom_model = EvalModel( + self.model, self.template, max_batch_size=max_batch_size, model_name=f'model-step{self.state.global_step}') + task_config = TaskConfig( + model=custom_model, + eval_type=EvalType.CUSTOM, + datasets=self.args.eval_datasets, + dataset_args=self.args.eval_datasets_args, + limit=self.args.eval_limit, + work_dir=os.path.join(self.args.output_dir, 'eval'), + eval_batch_size=max_batch_size, + generation_config=self.args.eval_generation_config or {'max_tokens': 512}, + ) + # start evaluation + eval_report = run_task(task_config) + # convert to dict + eval_dict = {f'test_{k}': v.score for k, v in eval_report.items()} + self.log(eval_dict) + + self.model.train() + return eval_dict + + def get_batch_samples(self, *args, **kwargs): + res = super().get_batch_samples(*args, **kwargs) + if self.template.sequence_parallel_size == 1: + return res + + batch_samples, num_items_in_batch = res + if num_items_in_batch is None: + num_items_in_batch = torch.tensor(0).to(args[2]) + from swift.trainers.sequence_parallel import sequence_parallel + dist.all_reduce(num_items_in_batch, dist.ReduceOp.SUM, sequence_parallel.sp_group) + return batch_samples, num_items_in_batch + + +class DataLoaderMixin: + + def get_train_dataloader(self): + dataloader = None + if self.template.sequence_parallel_size > 1: + from swift.trainers.sequence_parallel import sequence_parallel + dataloader = sequence_parallel.get_dataloader(self, self.train_dataset, self._train_batch_size) + if dataloader is None: + # Higher efficiency + if self.train_dataset is None: + raise ValueError('Trainer: training requires a train_dataset.') + args = self.args + train_dataset = self.train_dataset + + dataloader_params = { + 'collate_fn': self.data_collator, + 'num_workers': args.dataloader_num_workers, + 'pin_memory': args.dataloader_pin_memory, + 'persistent_workers': args.dataloader_persistent_workers, + 'prefetch_factor': args.dataloader_prefetch_factor + } + batch_sampler_params = { + 'drop_last': args.dataloader_drop_last, + 'shuffle': args.train_dataloader_shuffle, + 'data_seed': args.data_seed, + } + + if hasattr(train_dataset, '__len__'): + batch_sampler = BatchSamplerShard( + len(train_dataset), batch_size=self._train_batch_size, **batch_sampler_params) + dataloader = DataLoaderShard(train_dataset, batch_sampler, 
**dataloader_params) + else: + # IterableDataset + if dist.is_initialized() and dataloader_params['prefetch_factor']: + dataloader_params['prefetch_factor'] = dataloader_params['prefetch_factor'] * dist.get_world_size() + dataloader = DataLoader(train_dataset, batch_size=self._train_batch_size, **dataloader_params) + dataloader = DataLoaderDispatcher(dataloader) + + return dataloader + + def get_eval_dataloader(self, eval_dataset=None): + dataloader = None + if self.template.sequence_parallel_size > 1: + from swift.trainers.sequence_parallel import sequence_parallel + if eval_dataset is None and self.eval_dataset is None: + raise ValueError('Trainer: evaluation requires an eval_dataset.') + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + dataloader = sequence_parallel.get_dataloader(self, eval_dataset, self.args.eval_batch_size) + if dataloader is None: + return super().get_eval_dataloader(eval_dataset=eval_dataset) + return dataloader diff --git a/swift/trainers/optimizers/__init__.py b/swift/trainers/optimizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b937315b6e719ae8289fee2908aa486222eb76c5 --- /dev/null +++ b/swift/trainers/optimizers/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. diff --git a/swift/trainers/optimizers/__pycache__/__init__.cpython-310.pyc b/swift/trainers/optimizers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96eb047836e29ce7f83412244fe3f0a25a26e2f2 Binary files /dev/null and b/swift/trainers/optimizers/__pycache__/__init__.cpython-310.pyc differ diff --git a/swift/trainers/optimizers/galore/__init__.py b/swift/trainers/optimizers/galore/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..822853cd8c7f8a585138c45fbc9e5a44f749efb5 --- /dev/null +++ b/swift/trainers/optimizers/galore/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
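+#
+# The symbols in `_import_structure` below are resolved lazily through `_LazyModule`.
+# Illustrative usage (a sketch only; in practice `utils.create_optimizer_and_scheduler`
+# builds the parameter groups, and GaLore projection is enabled per group via the
+# 'rank'/'update_proj_gap'/'scale'/'proj_type' keys -- `galore_params` / `other_params`
+# below are placeholders for the caller's own parameter lists):
+#
+#     from swift.trainers.optimizers.galore import GaLoreAdamW
+#
+#     param_groups = [
+#         {'params': galore_params, 'rank': 128, 'update_proj_gap': 50,
+#          'scale': 1.0, 'proj_type': 'std'},
+#         {'params': other_params},
+#     ]
+#     optimizer = GaLoreAdamW(param_groups, lr=1e-4)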
+ +from typing import TYPE_CHECKING + +from swift.utils.import_utils import _LazyModule + +if TYPE_CHECKING: + from .utils import create_optimizer_and_scheduler, GaLoreConfig + from .adafactor import GaLoreAdafactor + from .adamw8bit import GaLoreAdamW8bit + from .adamw import GaLoreAdamW +else: + _import_structure = { + 'utils': ['GaLoreConfig', 'create_optimizer_and_scheduler'], + 'adafactor': ['GaLoreAdafactor'], + 'adamw8bit': ['GaLoreAdamW8bit'], + 'adamw': ['GaLoreAdamW'], + } + + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/swift/trainers/optimizers/galore/__pycache__/__init__.cpython-310.pyc b/swift/trainers/optimizers/galore/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c5cb466f9ade41411d19e43a5eef94063dcd8e4 Binary files /dev/null and b/swift/trainers/optimizers/galore/__pycache__/__init__.cpython-310.pyc differ diff --git a/swift/trainers/optimizers/galore/__pycache__/utils.cpython-310.pyc b/swift/trainers/optimizers/galore/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b478ba503b1edb57cf46adfa203054a26b376830 Binary files /dev/null and b/swift/trainers/optimizers/galore/__pycache__/utils.cpython-310.pyc differ diff --git a/swift/trainers/optimizers/galore/adafactor.py b/swift/trainers/optimizers/galore/adafactor.py new file mode 100644 index 0000000000000000000000000000000000000000..98ab26477ad4d53ad1dc7de19324794cf24ae001 --- /dev/null +++ b/swift/trainers/optimizers/galore/adafactor.py @@ -0,0 +1,272 @@ +# copy dependencies from transformers/optimization.py +# code borrowed from https://github.com/jiaweizzhao/GaLore +import math + +import torch +from torch.optim import Optimizer +from transformers.utils.versions import require_version + +from .galore_projector import GaLoreProjector + + +class Adafactor(Optimizer): + """ + AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code: + https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py + + Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that + this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and + `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and + `relative_step=False`. + + Arguments: + params (`Iterable[nn.parameter.Parameter]`): + Iterable of parameters to optimize or dictionaries defining parameter groups. + lr (`float`, *optional*): + The external learning rate. 
+ eps (`Tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`): + Regularization constants for square gradient and parameter scale respectively + clip_threshold (`float`, *optional*, defaults to 1.0): + Threshold of root mean square of final gradient update + decay_rate (`float`, *optional*, defaults to -0.8): + Coefficient used to compute running averages of square + beta1 (`float`, *optional*): + Coefficient used for computing running averages of gradient + weight_decay (`float`, *optional*, defaults to 0.0): + Weight decay (L2 penalty) + scale_parameter (`bool`, *optional*, defaults to `True`): + If True, learning rate is scaled by root mean square + relative_step (`bool`, *optional*, defaults to `True`): + If True, time-dependent learning rate is computed instead of external learning rate + warmup_init (`bool`, *optional*, defaults to `False`): + Time-dependent learning rate computation depends on whether warm-up initialization is being used + + This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested. + + Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3): + + - Training without LR warmup or clip_threshold is not recommended. + + - use scheduled LR warm-up to fixed LR + - use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235) + - Disable relative updates + - Use scale_parameter=False + - Additional optimizer operations like gradient clipping should not be used alongside Adafactor + + Example: + + ```python + Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3) + ``` + + Others reported the following combination to work well: + + ```python + Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None) + ``` + + When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`] + scheduler as following: + + ```python + from transformers.optimization import Adafactor, AdafactorSchedule + + optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None) + lr_scheduler = AdafactorSchedule(optimizer) + trainer = Trainer(..., optimizers=(optimizer, lr_scheduler)) + ``` + + Usage: + + ```python + # replace AdamW with Adafactor + optimizer = Adafactor( + model.parameters(), + lr=1e-3, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + relative_step=False, + scale_parameter=False, + warmup_init=False, + ) + ```""" + + def __init__( + self, + params, + lr=None, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + scale_parameter=True, + relative_step=True, + warmup_init=False, + ): + require_version('torch>=1.5.0') # add_ with alpha + if lr is not None and relative_step: + raise ValueError('Cannot combine manual `lr` and `relative_step=True` options') + if warmup_init and not relative_step: + raise ValueError('`warmup_init=True` requires `relative_step=True`') + + defaults = { + 'lr': lr, + 'eps': eps, + 'clip_threshold': clip_threshold, + 'decay_rate': decay_rate, + 'beta1': beta1, + 'weight_decay': weight_decay, + 'scale_parameter': scale_parameter, + 'relative_step': relative_step, + 'warmup_init': warmup_init, + } + super().__init__(params, defaults) + + @staticmethod + def _get_lr(param_group, param_state): + rel_step_sz = param_group['lr'] + if param_group['relative_step']: + min_step = 1e-6 * param_state['step'] if 
param_group['warmup_init'] else 1e-2 + rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state['step'])) + param_scale = 1.0 + if param_group['scale_parameter']: + param_scale = max(param_group['eps'][1], param_state['RMS']) + return param_scale * rel_step_sz + + @staticmethod + def _get_options(param_group, param_shape): + factored = len(param_shape) >= 2 + use_first_moment = param_group['beta1'] is not None + return factored, use_first_moment + + @staticmethod + def _rms(tensor): + return tensor.norm(2) / (tensor.numel()**0.5) + + @staticmethod + def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): + # copy from fairseq's adafactor implementation: + # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505 + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + @torch.no_grad() + def step(self, closure=None): + """ + Performs a single optimization step + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError('Adafactor does not support sparse gradients.') + + state = self.state[p] + + if 'step' not in state: + state['step'] = 0 + + # GaLore Projection + if 'rank' in group: + if 'projector' not in state: + state['projector'] = GaLoreProjector( + group['rank'], + update_proj_gap=group['update_proj_gap'], + scale=group['scale'], + proj_type=group['proj_type']) + + grad = state['projector'].project(grad, state['step']) + + grad_shape = grad.shape + + factored, use_first_moment = self._get_options(group, grad_shape) + # State Initialization + if 'RMS' not in state: + state['step'] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(grad) + if factored: + state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad) + state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) + else: + state['exp_avg_sq'] = torch.zeros_like(grad) + + state['RMS'] = 0 + else: + if use_first_moment: + state['exp_avg'] = state['exp_avg'].to(grad) + if factored: + state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad) + state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad) + else: + state['exp_avg_sq'] = state['exp_avg_sq'].to(grad) + + p_data_fp32 = p + if p.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state['step'] += 1 + state['RMS'] = self._rms(p_data_fp32) + lr = self._get_lr(group, state) + + beta2t = 1.0 - math.pow(state['step'], group['decay_rate']) + update = (grad**2) + group['eps'][0] + if factored: + exp_avg_sq_row = state['exp_avg_sq_row'] + exp_avg_sq_col = state['exp_avg_sq_col'] + + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) + + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state['exp_avg_sq'] + + exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) + update = 
exp_avg_sq.rsqrt().mul_(grad) + + update.div_((self._rms(update) / group['clip_threshold']).clamp_(min=1.0)) + update.mul_(lr) + + if use_first_moment: + exp_avg = state['exp_avg'] + exp_avg.mul_(group['beta1']).add_(update, alpha=(1 - group['beta1'])) + update = exp_avg + + # GaLore Projection Back + if 'rank' in group: + update = state['projector'].project_back(update) + + if group['weight_decay'] != 0: + p_data_fp32.add_(p_data_fp32, alpha=(-group['weight_decay'] * lr)) + + p_data_fp32.add_(-update) + + if p.dtype in {torch.float16, torch.bfloat16}: + p.copy_(p_data_fp32) + + return loss + + +GaLoreAdafactor = Adafactor diff --git a/swift/trainers/optimizers/galore/adamw.py b/swift/trainers/optimizers/galore/adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..7396334a32d974a3631e30862a384f908a6816f4 --- /dev/null +++ b/swift/trainers/optimizers/galore/adamw.py @@ -0,0 +1,141 @@ +# copy dependencies from transformers/optimization.py +# code borrowed from https://github.com/jiaweizzhao/GaLore +import math +from typing import Callable, Iterable, Tuple + +import torch +from torch import nn +from torch.optim import Optimizer +from transformers.utils.versions import require_version + +from .galore_projector import GaLoreProjector + + +class AdamW(Optimizer): + """ + Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay + Regularization](https://arxiv.org/abs/1711.05101). + + Parameters: + params (`Iterable[nn.parameter.Parameter]`): + Iterable of parameters to optimize or dictionaries defining parameter groups. + lr (`float`, *optional*, defaults to 0.001): + The learning rate to use. + betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`): + Adam's betas parameters (b1, b2). + eps (`float`, *optional*, defaults to 1e-06): + Adam's epsilon for numerical stability. + weight_decay (`float`, *optional*, defaults to 0.0): + Decoupled weight decay to apply. + correct_bias (`bool`, *optional*, defaults to `True`): + Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`). + no_deprecation_warning (`bool`, *optional*, defaults to `False`): + A flag used to disable the deprecation warning (set to `True` to disable the warning). + """ + + def __init__( + self, + params: Iterable[nn.parameter.Parameter], + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-6, + weight_decay: float = 0.0, + correct_bias: bool = True, + no_deprecation_warning: bool = False, + ): + require_version('torch>=1.5.0') # add_ with alpha + if lr < 0.0: + raise ValueError(f'Invalid learning rate: {lr} - should be >= 0.0') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps} - should be >= 0.0') + defaults = {'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay, 'correct_bias': correct_bias} + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure: Callable = None): + """ + Performs a single optimization step. + + Arguments: + closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss. 
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + + state = self.state[p] + + if 'step' not in state: + state['step'] = 0 + + # GaLore Projection + if 'rank' in group: + if 'projector' not in state: + state['projector'] = GaLoreProjector( + group['rank'], + update_proj_gap=group['update_proj_gap'], + scale=group['scale'], + proj_type=group['proj_type']) + + grad = state['projector'].project(grad, state['step']) + + # State initialization + if 'exp_avg' not in state: + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(grad) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(grad) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + # Decay the first and second moment running average coefficient + # In-place operations to update the averages at the same time + exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1)) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + denom = exp_avg_sq.sqrt().add_(group['eps']) + + step_size = group['lr'] + if group['correct_bias']: # No bias correction for Bert + bias_correction1 = 1.0 - beta1**state['step'] + bias_correction2 = 1.0 - beta2**state['step'] + step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 + + # compute norm gradient + norm_grad = exp_avg / denom + + # GaLore Projection Back + if 'rank' in group: + norm_grad = state['projector'].project_back(norm_grad) + + p.add_(norm_grad, alpha=-step_size) + + # Just adding the square of the weights to the loss function is *not* + # the correct way of using L2 regularization/weight decay with Adam, + # since that will interact with the m and v parameters in strange ways. + # + # Instead we want to decay the weights in a manner that doesn't interact + # with the m/v parameters. This is equivalent to adding the square + # of the weights to the loss with plain (non-momentum) SGD. + # Add weight decay at the end (fixed version) + if group['weight_decay'] > 0.0: + p.add_(p, alpha=(-group['lr'] * group['weight_decay'])) + + return loss + + +GaLoreAdamW = AdamW diff --git a/swift/trainers/optimizers/galore/adamw8bit.py b/swift/trainers/optimizers/galore/adamw8bit.py new file mode 100644 index 0000000000000000000000000000000000000000..66b0c5b621369ec16577729df5251848a8796e90 --- /dev/null +++ b/swift/trainers/optimizers/galore/adamw8bit.py @@ -0,0 +1,112 @@ +# code borrowed from https://github.com/jiaweizzhao/GaLore +import torch +from bitsandbytes.optim.optimizer import Optimizer2State + +from .galore_projector import GaLoreProjector + + +class AdamW8bit(Optimizer2State): + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False): + super().__init__( + 'adam', + params, + lr, + betas, + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if not self.initialized: + self.check_overrides() + self.to_gpu() # needed for fairseq pure fp16 training + self.initialized = True + + # if self.is_paged: self.page_mng.prefetch_all() + for gindex, group in enumerate(self.param_groups): + for pindex, p in enumerate(group['params']): + if p.grad is None: + continue + state = self.state[p] + + if 'step' not in state: + state['step'] = 0 + + # GaLore Projection + if 'rank' in group: + if 'projector' not in state: + state['projector'] = GaLoreProjector( + group['rank'], + update_proj_gap=group['update_proj_gap'], + scale=group['scale'], + proj_type=group['proj_type']) + + if 'weight_decay' in group and group['weight_decay'] > 0: + # ensure that the weight decay is not applied to the norm grad + group['weight_decay_saved'] = group['weight_decay'] + group['weight_decay'] = 0 + + grad = state['projector'].project(p.grad, state['step']) + + # suboptimal implementation + p.saved_data = p.data.clone() + p.data = grad.clone().to(p.data.dtype).to(p.data.device) + p.data.zero_() + p.grad = grad + + if 'state1' not in state: + self.init_state(group, p, gindex, pindex) + + self.prefetch_state(p) + self.update_step(group, p, gindex, pindex) + torch.cuda.synchronize() + + # GaLore Projection Back + if 'rank' in group: + p.data = p.saved_data.add_(state['projector'].project_back(p.data)) + + # apply weight decay + if 'weight_decay_saved' in group: + p.data.add_(p.data, alpha=-group['lr'] * group['weight_decay_saved']) + group['weight_decay'] = group['weight_decay_saved'] + del group['weight_decay_saved'] + + if self.is_paged: + # all paged operation are asynchronous, we need + # to sync to make sure all tensors are in the right state + torch.cuda.synchronize() + + return loss + + +GaLoreAdamW8bit = AdamW8bit diff --git a/swift/trainers/optimizers/galore/galore_projector.py b/swift/trainers/optimizers/galore/galore_projector.py new file mode 100644 index 0000000000000000000000000000000000000000..52fa1f0f3a3abcb92cc029f29ce390a3760667cf --- /dev/null +++ b/swift/trainers/optimizers/galore/galore_projector.py @@ -0,0 +1,109 @@ +# code borrowed from https://github.com/jiaweizzhao/GaLore + +import torch + + +class GaLoreProjector: + + def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type='std'): + self.rank = rank + self.verbose = verbose + self.update_proj_gap = update_proj_gap + self.scale = scale + self.ortho_matrix = None + self.proj_type = proj_type + + def project(self, full_rank_grad, iter): + + if self.proj_type == 'std': + if full_rank_grad.shape[0] >= full_rank_grad.shape[1]: + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right') + low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t()) + else: + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left') + low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad) + elif self.proj_type == 'reverse_std': + if full_rank_grad.shape[0] >= full_rank_grad.shape[1]: + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left') + low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad) + else: + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + 
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right') + low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t()) + elif self.proj_type == 'right': + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right') + low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t()) + elif self.proj_type == 'left': + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left') + low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad) + elif self.proj_type == 'full': + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='full') + low_rank_grad = torch.matmul(self.ortho_matrix[0].t(), full_rank_grad) @ self.ortho_matrix[1].t() + + return low_rank_grad + + def project_back(self, low_rank_grad): + + if self.proj_type == 'std': + if low_rank_grad.shape[0] >= low_rank_grad.shape[1]: + full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix) + else: + full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad) + elif self.proj_type == 'reverse_std': + if low_rank_grad.shape[0] <= low_rank_grad.shape[1]: # note this is different from std + full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad) + else: + full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix) + elif self.proj_type == 'right': + full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix) + elif self.proj_type == 'left': + full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad) + elif self.proj_type == 'full': + full_rank_grad = torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1] + + return full_rank_grad * self.scale + + # svd decomposition + def get_orthogonal_matrix(self, weights, rank, type): + module_params = weights + + if module_params.data.dtype != torch.float: + float_data = False + original_type = module_params.data.dtype + original_device = module_params.data.device + matrix = module_params.data.float() + else: + float_data = True + matrix = module_params.data + + U, s, Vh = torch.linalg.svd(matrix, full_matrices=False) + + # make the smaller matrix always to be orthogonal matrix + if type == 'right': + A = U[:, :rank] @ torch.diag(s[:rank]) + B = Vh[:rank, :] + + if not float_data: + B = B.to(original_device).type(original_type) + return B + elif type == 'left': + A = U[:, :rank] + B = torch.diag(s[:rank]) @ Vh[:rank, :] + if not float_data: + A = A.to(original_device).type(original_type) + return A + elif type == 'full': + A = U[:, :rank] + B = Vh[:rank, :] + if not float_data: + A = A.to(original_device).type(original_type) + B = B.to(original_device).type(original_type) + return [A, B] + else: + raise ValueError('type should be left, right or full') diff --git a/swift/trainers/optimizers/galore/utils.py b/swift/trainers/optimizers/galore/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7e9f243f8cba23547e5a0147d9b236c13cf7dfdc --- /dev/null +++ b/swift/trainers/optimizers/galore/utils.py @@ -0,0 +1,214 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
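+#
+# Illustrative call of `create_optimizer_and_scheduler` defined below (a sketch; `args`
+# is the trainer's `transformers.TrainingArguments` and `max_steps` the planned number
+# of optimizer steps, both supplied by the caller):
+#
+#     config = GaLoreConfig(rank=128, target_modules=['q_proj', 'k_proj', 'v_proj'])
+#     optimizer, lr_scheduler = create_optimizer_and_scheduler(
+#         model, args, config, max_steps, weight_decay=args.weight_decay)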
+import importlib +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple, Union + +import torch +from torch import nn +from torch.optim import Optimizer +from transformers import Trainer, TrainingArguments, get_scheduler + +from swift.utils import get_logger + +try: + from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +except ImportError: + from torch.optim.lr_scheduler import LRScheduler + +logger = get_logger() + + +@dataclass +class GaLoreConfig: + """ + The configuration class for the Galore module. + + + See https://arxiv.org/abs/2403.03507 + + Args: + rank (`int`): The galore rank + target_modules (`Union[str, List[str]]`): The target modules to use, if `None`, + will use all attn and mlp linears + update_proj_gap(`int`): The projection update interval for galore + proj_type(`str`) The project type of Galore, valid values are `std`, + `reverse_std`, `right`, `left`, `full` + galore_scale(float): the scale of gradient + optim_per_parameter(bool): Gives one optimizer per parameter + """ + rank: int = 128 + target_modules: Union[str, List[str]] = None + update_proj_gap: int = 50 + galore_scale: float = 1.0 + proj_type: str = 'std' + optim_per_parameter: bool = False + quantize: bool = False + proj_quant: bool = False + proj_bits: int = 4 + proj_group_size: int = 256 + cos_threshold: float = 0.4 + gamma_proj: int = 2 + queue_size: int = 5 + + +class GaloreOptimizerWrapper(Optimizer): + + def __init__(self, optimizers: Dict[Any, Optimizer]): + self.optimizers = optimizers + super().__init__([torch.tensor([1., 2., 3.])], {'lr': 1.}) + + def zero_grad(self, *args, **kwargs) -> None: + for optim in self.optimizers.values(): + optim.zero_grad(*args, **kwargs) + + def step(self, *args, **kwargs) -> None: + for optim in self.optimizers.values(): + optim.step(*args, **kwargs) + + +class GaloreSchedulerWrapper(LRScheduler): + + def __init__(self, lr_schedulers: Dict[Any, LRScheduler]): + self.lr_schedulers = lr_schedulers + + def step(self, *args, **kwargs) -> None: + for lr_scheduler in self.lr_schedulers.values(): + lr_scheduler.step(*args, **kwargs) + self._last_lr = lr_scheduler.get_last_lr() + + +def create_optimizer_and_scheduler(model: nn.Module, args: TrainingArguments, config: GaLoreConfig, max_steps, + **defaults): + galore_params = [] + for module_name, module in model.named_modules(): + if not isinstance(module, (nn.Linear, nn.Embedding)) or \ + not any(target_key in module_name for target_key in config.target_modules): + continue + + if not module.weight.requires_grad: + continue + + logger.info(f'Enable GaLore for weights in module: {module_name}') + galore_params.append(module.weight) + + id_galore_params = [id(p) for p in galore_params] + galore_defaults = { + 'rank': config.rank, + 'update_proj_gap': config.update_proj_gap, + 'scale': config.galore_scale, + 'proj_type': config.proj_type, + **defaults + } + if config.quantize: + galore_defaults['quant'] = config.proj_quant + galore_defaults['quant_n_bit'] = config.proj_bits + galore_defaults['quant_group_size'] = config.proj_group_size + galore_defaults['cos_threshold'] = config.cos_threshold + galore_defaults['gamma_proj'] = config.gamma_proj + galore_defaults['queue_size'] = config.queue_size + optim_cls, optim_kwargs = get_optimizer(args, config) + + if config.optim_per_parameter and not config.quantize: + # q-galore does not support optim_per_parameter + optimizer_dict = {} + galore_defaults['update_proj_gap'] = galore_defaults['update_proj_gap'] * 2 + for p in model.parameters(): + if 
p.requires_grad: + if id(p) in id_galore_params: + optimizer_dict[p] = optim_cls([{'params': [p], **galore_defaults}], **optim_kwargs) + else: + optimizer_dict[p] = optim_cls([{'params': [p], **defaults}], **optim_kwargs) + + # get scheduler dict + scheduler_dict = {} + for p in model.parameters(): + if p.requires_grad: + scheduler_dict[p] = get_scheduler( + optimizer=optimizer_dict[p], + name=args.lr_scheduler_type, + num_training_steps=max_steps * 2, + num_warmup_steps=args.warmup_steps * 2, + scheduler_specific_kwargs=args.lr_scheduler_kwargs, + ) + + return GaloreOptimizerWrapper(optimizer_dict), GaloreSchedulerWrapper(scheduler_dict) + else: + decay_parameters = Trainer.get_decay_parameter_names(Trainer, model) + param_groups = [{ + 'params': galore_params, + **galore_defaults, + }] + param_groups.extend([ + { + 'params': [ + p for n, p in model.named_parameters() + if (n in decay_parameters and id(p) not in id_galore_params and p.requires_grad) + ], + 'weight_decay': + defaults['weight_decay'], + }, + { + 'params': [ + p for n, p in model.named_parameters() + if (n not in decay_parameters and id(p) not in id_galore_params and p.requires_grad) + ], + 'weight_decay': + 0.0, + }, + ]) + optim = optim_cls(param_groups, **optim_kwargs) + scheduler = get_scheduler( + optimizer=optim, + name=args.lr_scheduler_type, + num_training_steps=max_steps, + num_warmup_steps=args.warmup_steps, + scheduler_specific_kwargs=args.lr_scheduler_kwargs, + ) + return optim, scheduler + + +def get_optimizer(args: TrainingArguments, config: GaLoreConfig) -> Tuple[Any, Any]: + # parse args.optim_args + optim_args = {} + if args.optim_args: + for mapping in args.optim_args.replace(' ', '').split(','): + key, value = mapping.split('=') + optim_args[key] = value + + optimizer_kwargs = {'lr': args.learning_rate} + + adam_kwargs = { + 'betas': (args.adam_beta1, args.adam_beta2), + 'eps': args.adam_epsilon, + } + if args.optim == 'adafactor': + from .adafactor import GaLoreAdafactor + optimizer_cls = GaLoreAdafactor + optimizer_kwargs.update({'scale_parameter': False, 'relative_step': False}) + elif args.optim in ('adamw_hf', 'adamw_torch'): + if config.quantize: + assert importlib.util.find_spec('q_galore_torch') is not None, \ + 'Please install q-galore by `pip install q_galore_torch`' + logger.info('If you encounter `absmax2` error, please downgrade your bitsandbytes to 0.40.0') + from swift.utils import get_dist_setting + _, _, world_size, _ = get_dist_setting() + if world_size > 1: + # from q_galore_torch import QGaLoreAdamW8bit_simulate as GaLoreAdamW + from q_galore_torch import QGaLoreAdamW8bit as GaLoreAdamW + else: + from q_galore_torch import QGaLoreAdamW8bit as GaLoreAdamW + else: + from .adamw import GaLoreAdamW + optimizer_cls = GaLoreAdamW + optimizer_kwargs.update(adam_kwargs) + elif 'adamw' in args.optim and '8bit' in args.optim: + try: + from .adamw8bit import GaLoreAdamW8bit + optimizer_cls = GaLoreAdamW8bit + optimizer_kwargs.update(adam_kwargs) + optimizer_kwargs.update({'optim_bits': 8, 'is_paged': 'paged' in args.optim}) + except ImportError: + raise ValueError('Trainer tried to instantiate bnb optimizer but bnb is not installed!') + else: + raise ValueError(f'Galore not supported for optimizer type: {args.optim}') + return optimizer_cls, optimizer_kwargs diff --git a/swift/trainers/rlhf_arguments.py b/swift/trainers/rlhf_arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..268bca7aad8cfca2e57a589db6ec60b9d3f8feef --- /dev/null +++ b/swift/trainers/rlhf_arguments.py 
@@ -0,0 +1,63 @@ +from dataclasses import dataclass, field +from typing import List + +from trl import CPOConfig as HfCPOConfig +from trl import DPOConfig as HfDPOConfig +from trl import GRPOConfig as HfGRPOConfig +from trl import KTOConfig as HfKTOConfig +from trl import ORPOConfig as HfORPOConfig +from trl import PPOConfig as HfPPOConfig +from trl import RewardConfig as HfRewardConfig + +from .arguments import GRPOArgumentsMixin, SwiftArgumentsMixin + + +@dataclass +class DPOConfig(SwiftArgumentsMixin, HfDPOConfig): + pass + + +@dataclass +class CPOConfig(SwiftArgumentsMixin, HfCPOConfig): + pass + + +@dataclass +class ORPOConfig(SwiftArgumentsMixin, HfORPOConfig): + pass + + +@dataclass +class KTOConfig(SwiftArgumentsMixin, HfKTOConfig): + pass + + +@dataclass +class RewardConfig(SwiftArgumentsMixin, HfRewardConfig): + pass + + +@dataclass +class PPOConfig(SwiftArgumentsMixin, HfPPOConfig): + pass + + +@dataclass +class GRPOConfig(GRPOArgumentsMixin, SwiftArgumentsMixin, HfGRPOConfig): + stop_words: List[str] = field(default_factory=list) + + def __post_init__(self): + from swift.llm.argument.base_args.model_args import ModelArguments + super().__post_init__() + if self.cosine_max_len is None: + self.cosine_max_len = self.max_completion_length + self.vllm_limit_mm_per_prompt = ModelArguments.parse_to_dict(self.vllm_limit_mm_per_prompt) + + if self.deepspeed and 'zero_optimization' in self.deepspeed and self.deepspeed['zero_optimization'][ + 'stage'] == 3: + # https://github.com/modelscope/ms-swift/issues/3237 + self.deepspeed['zero_optimization']['stage3_prefetch_bucket_size'] = 0 + self.deepspeed_plugin.hf_ds_config.config['zero_optimization']['stage3_prefetch_bucket_size'] = 0 + + # https://github.com/modelscope/ms-swift/issues/3863 + self.dataloader_drop_last = True diff --git a/swift/trainers/rlhf_trainer/.ipynb_checkpoints/grpo_trainer-checkpoint.py b/swift/trainers/rlhf_trainer/.ipynb_checkpoints/grpo_trainer-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..a18db0f13fa9984c4b8ae4708f5a7f0a8321a063 --- /dev/null +++ b/swift/trainers/rlhf_trainer/.ipynb_checkpoints/grpo_trainer-checkpoint.py @@ -0,0 +1,1426 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Part of the implementation is borrowed from huggingface/trl. 
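+# This module defines `GRPOTrainer`: it supports plugin reward functions and reward
+# models, multi-turn rollouts, rollout generation through vLLM/LMDeploy or the PyTorch
+# engine, asynchronous generation, and batched weight synchronization to the inference
+# engine (`move_model_batches`).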
+import concurrent.futures +import inspect +import os +import re +import time +from collections import defaultdict, deque +from concurrent.futures import Future +from contextlib import contextmanager +from copy import copy, deepcopy +from dataclasses import asdict, dataclass, field +from math import ceil +from queue import Queue +from types import MethodType +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import datasets +import numpy as np +import torch +import torch.nn as nn +import transformers +from accelerate.utils import gather, gather_object, is_peft_model, set_seed +from packaging import version +from torch.nn import ModuleList +from torch.utils.data import DataLoader +from transformers import PreTrainedModel, TrainerCallback +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.trainer import Trainer +from transformers.trainer_utils import seed_worker +from trl import GRPOTrainer as HFGRPOTrainer +from trl.extras.profiling import profiling_decorator +from trl.models import prepare_deepspeed +from trl.trainer.grpo_trainer import nanmax, nanmin + +from swift.llm import InferRequest, MultiModelKeys, RequestConfig, RowPreprocessor, get_model_arch, to_device +from swift.llm.infer.infer_engine import set_device_context +from swift.llm.template.template_inputs import StdTemplateInputs +from swift.plugin import multi_turns, orms, rm_plugins +from swift.utils import (JsonlWriter, gc_collect, get_device, get_device_count, get_dist_setting, get_logger, + get_node_setting, is_lmdeploy_available, is_vllm_available, is_wandb_available) +from ..mixin import SwiftMixin +from .rlhf_mixin import RLHFTrainerMixin +from .utils import patch_lora_merge, patch_lora_unmerge, round_robin + +del HFGRPOTrainer.__init__ +del HFGRPOTrainer.log + +logger = get_logger() +if is_wandb_available(): + import wandb + os.environ["WANDB_API_KEY"] = "a7ab128385681b17ad156ad0d8c81ba3e2296164" + os.environ["WANDB_MODE"] = "offline" + +InputsType = List[Dict[str, Union[torch.Tensor, Any]]] +OutputsType = List[List[Tuple[List[Dict], str]]] + + +@contextmanager +def unwrap_model_for_generation( + model, + accelerator, + gather_deepspeed3_params=True, + gather_parameters: List = None, +): + unwrapped_model = accelerator.unwrap_model(model) + if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3: + if not gather_deepspeed3_params: + yield accelerator.unwrap_model(model) + else: + import deepspeed + parameters = [ + parameter for name, parameter in model.named_parameters() + if not gather_parameters or name in gather_parameters + ] + with deepspeed.zero.GatheredParameters(parameters): + from trl.models.utils import remove_hooks + remove_hooks(model) + yield accelerator.unwrap_model(model) + from trl.models.utils import add_hooks + add_hooks(model) + else: + yield unwrapped_model + + +class GRPOCallback(TrainerCallback): + + def __init__(self, trainer): + self.trainer = trainer + + # offload original_modules to cpu, to save memory + def on_train_begin(self, args, state, control, **kwargs): + self.trainer.queue = self.trainer.train_queue + train_dataloader = getattr(state, 'train_dataloader', None) or kwargs.get('train_dataloader') + self.trainer._prefetch(train_dataloader) + + +@dataclass +class DataCache: + inputs: List[Dict] = field(default_factory=list) + outputs: List[Dict] = field(default_factory=list) + distributed_idx: List[List] = field(default_factory=list) + + +class GRPOTrainer(RLHFTrainerMixin, SwiftMixin, 
HFGRPOTrainer): + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + + def __init__(self, + model: Optional[Union[PreTrainedModel, nn.Module]] = None, + ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None, + reward_model: Optional[List[Union[PreTrainedModel, nn.Module]]] = None, + reward_funcs: Optional[List[Union[str, Callable]]] = None, + *_args, + **kwargs): + from swift.trainers.rlhf_arguments import GRPOConfig + args: GRPOConfig = kwargs['args'] + self.args = args + self.train_queue = Queue() + self.eval_queue = Queue() + self.processing_class = kwargs.get('template').tokenizer + self.offload_modules = {} + self.offload_states = {} + _, _, _, local_world_size = get_dist_setting() + + if not isinstance(reward_funcs, list): + reward_funcs = [reward_funcs] + + if reward_funcs: + for i, reward_func in enumerate(reward_funcs): + if reward_func in orms: + reward_func_class = orms[reward_func] + reward_func_args = list(inspect.signature(reward_func_class.__init__).parameters) + reward_func_kwargs = { + key: getattr(args, key) + for key in reward_func_args if key not in ['self', 'args', 'kwargs'] and hasattr(args, key) + } + if 'tokenizer' in reward_func_args: + reward_func_kwargs['tokenizer'] = self.processing_class + reward_funcs[i] = reward_func_class(**reward_func_kwargs) + elif not callable(reward_func): + raise ValueError(f'reward_function {reward_func} is not implemented in swift.llm.plugin') + + self.reward_funcs = reward_funcs + self.reward_func_names = [] + for reward_func in reward_funcs: + if inspect.isfunction(reward_func): + reward_func_name = reward_func.__name__ + else: + reward_func_name = reward_func.__class__.__name__ + self.reward_func_names.append(reward_func_name) + + self.reward_model_plugins = [None] * len(self.reward_funcs) + + if reward_model is not None: + reward_template = kwargs.pop('reward_template') + reward_plugins = args.reward_model_plugin + if reward_plugins is None: + reward_plugins = ['default'] * len(reward_model) + assert len(reward_plugins) == len(reward_model), ( + f"The number of 'reward_model_plugin' ({len(reward_plugins)}) does not match " + f"the number of 'reward_model' ({len(reward_model)}). " + "Please provide a corresponding 'reward_model_plugin' for each 'reward_model'.") + for rm, rm_plugin, rm_template in zip(reward_model, reward_plugins, reward_template): + # Set encoding mode train(see details in Template.encode). + # Set max_length to None to disable truncation, as the input length has already been truncated earlier. 
+ rm_template.set_mode('train') + rm_template.max_length = None + if rm_plugin not in rm_plugins: + raise ValueError(f'rm_plugin {rm_plugin} is not implemented in swift.llm.plugin') + self.reward_model_plugins.append(rm_plugins[rm_plugin](model=rm, template=rm_template)) + self.reward_funcs.append(rm) + self.reward_func_names.append(rm.config._name_or_path.split('/')[-1]) + + if not self.reward_funcs: + raise ValueError('You must specify reward_funcs or reward_model') + + # Reward weights + if args.reward_weights is not None: + if len(args.reward_weights) != len(reward_funcs): + raise ValueError(f'Number of reward weights ({len(args.reward_weights)}) must match number of reward ' + f'functions ({len(reward_funcs)})') + self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) + else: + self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32) + + self.multi_turn_func = None + if self.args.multi_turn_func: + if isinstance(self.args.multi_turn_func, str): + assert self.args.multi_turn_func in multi_turns + multi_turn_func = multi_turns[self.args.multi_turn_func] + self.multi_turn_func = multi_turn_func + else: + self.multi_turn_func = self.args.multi_turn_func + + self.num_generations = args.num_generations + self.temperature = args.temperature + self.loss_type = args.loss_type + model.warnings_issued['estimate_tokens'] = True + kwargs['data_collator'] = lambda features: features + self.shuffle_dataset = args.dataset_shuffle + + use_vllm = args.use_vllm + use_lmdeploy = args.use_lmdeploy + vllm_client = kwargs.pop('vllm_client') # for external vllm + if self.args.tensor_parallel_size > 1 and self.multi_turn_func: + import torch.distributed as dist + rank, _, _, _ = get_dist_setting() + for tp_group in self.tp_group_ranks(): + group = dist.new_group(tp_group) + if rank in tp_group: + self.group = group + + super().__init__(model, ref_model, *_args, **kwargs) + + self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)} + self.log_completions = args.log_completions + self.wandb_log_unique_prompts = args.wandb_log_unique_prompts + self.num_completions_to_print = args.num_completions_to_print + self.jsonl_writer = JsonlWriter(os.path.join(self.args.output_dir, 'completions.jsonl')) + # maxlen is set to the total number of forward passes per step. This value of `maxlen` ensures we log only the + # final optimization step. + maxlen = self.accelerator.num_processes * args.per_device_train_batch_size * args.gradient_accumulation_steps + self._textual_logs = { + 'prompt': deque(maxlen=maxlen), + 'completion': deque(maxlen=maxlen), + 'rewards': defaultdict(lambda: deque(maxlen=maxlen)), + } + + num_processes = self.accelerator.num_processes + self.effective_train_batch_size = effective_batch_size = \ + args.per_device_train_batch_size * num_processes * args.gradient_accumulation_steps + possible_values = [n_gen for n_gen in range(2, effective_batch_size + 1) if (effective_batch_size) % n_gen == 0] + + if self.num_generations not in possible_values: + raise ValueError( + f'The effective train batch size ({num_processes} x {args.per_device_train_batch_size} x ' + f'{args.gradient_accumulation_steps}) must be evenly divisible by the number of generations per ' + f'prompt ({self.num_generations}). 
Given the current effective train batch size, the valid values for ' + f'the number of generations are: {possible_values}.') + if self.args.eval_strategy != 'no': + effective_batch_size = args.per_device_eval_batch_size * num_processes + possible_values = [ + n_gen for n_gen in range(2, effective_batch_size + 1) if (effective_batch_size) % n_gen == 0 + ] + if self.num_generations not in possible_values: + raise ValueError( + f'The effective eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be ' + f'evenly divisible by the number of generations per prompt ({self.num_generations}). Given the ' + 'current effective eval batch size, the valid values for the number of generations are: ' + f'{possible_values}.') + + # Ensure each process receives a unique seed to prevent duplicate completions when generating with + # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but + # it's safer to set it in all cases. + set_seed(args.seed, device_specific=True) + self.parameter_groups, self.parameter_groups_no_lora = self.split_batches() + self.infer_device = None + self.use_fast_infer = use_vllm or use_lmdeploy # whether to use the PT backend + self.is_external_vllm = use_vllm and args.vllm_server_host is not None + if self.use_fast_infer: + if self.infer_rank >= 0: + fast_infer_device = self.args.vllm_device or self.args.lmdeploy_device + if fast_infer_device[0] == 'auto': + if get_device_count() == 1: + fast_infer_device = [get_device()] # particular case when training with only 1 GPU: share it + else: + fast_infer_device = [] + for idx in range(get_device_count() - self.args.num_infer_workers, get_device_count()): + fast_infer_device.append(get_device(idx)) + + for _device in fast_infer_device: + # Check that the requested device is available + if _device.split(':')[0] in {'cuda', 'npu'} and int(_device.split(':')[1]) >= get_device_count(): + raise ValueError(f'The requested device for vllm ({_device}) is not available. ' + f'You are likely using vLLM ' + 'without restricting the number of GPUs for training. ' + 'Set the `--num_processes` argument to a ' + 'value lower than the number of GPUs available on your machine—typically, ' + 'reducing it by one is sufficient. ' + f'In your case: `--num_processes {get_device_count() - 1}`.') + + if use_vllm: + if not is_vllm_available(): + raise ImportError('vLLM is not available and `use_vllm` is set to True. ' + 'Please install vLLM with `pip install vllm -U` to use it.') + if self.is_external_vllm: + self.vllm_client = vllm_client + else: + self.engine = self.prepare_vllm(model, fast_infer_device) + self.infer_device = fast_infer_device[self.local_infer_rank] + elif use_lmdeploy: + if not is_lmdeploy_available(): + raise ImportError('LMDeploy is not available and `use_lmdeploy` is set to True.' 
+ 'Please install LMDeploy with `pip install lmdeploy -U` to use it.') + from swift.llm import LmdeployEngine + from swift.tuners import Swift + with Swift.grpo_context(model, self.template.processor): + fast_infer_device = int(fast_infer_device[self.local_infer_rank].split(':')[1]) + self.engine = LmdeployEngine( + model.model_dir, + model.model_info.torch_dtype, + model_type=model.model_meta.model_type, + devices=[fast_infer_device], + session_len=args.lmdeploy_session_len, + cache_max_entry_count=args.lmdeploy_cache_max_entry_count, + reload_weights=True) + self.infer_device = fast_infer_device + from lmdeploy.turbomind.turbomind import TurboMind + lmdeploy_engine = self.engine.engine.engine + assert isinstance(lmdeploy_engine, TurboMind), ( + "Currently only LMDeploy's TurboMind backend is supported. " + 'The current model is incompatible - please use vLLM or PyTorch backend instead.') + if not self.is_external_vllm: + self.engine.default_template = copy(self.template) # Avoid thread-unsafe modifications of the mode. + self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation + + # When using vLLM, the main process is responsible for loading the model weights. This can cause process + # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we + # synchronize all processes after vLLM has been fully initialized. + self.accelerator.wait_for_everyone() + else: + from swift.llm import PtEngine + self.engine = PtEngine.from_model_template(self.model, copy(self.template), max_batch_size=0) # 0: no limit + # Avoid thread-unsafe modifications of the mode. + self.request_config = RequestConfig( + max_tokens=args.max_completion_length, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + repetition_penalty=args.repetition_penalty, + stop=args.stop_words, + ) + + if local_world_size == self.args.num_infer_workers == get_device_count() and local_world_size > 1: + self.request_config.n = self.args.tensor_parallel_size + if self.infer_rank >= 0: + self.request_config.seed = self.infer_rank // self.args.tensor_parallel_size + + self.model_accepts_loss_kwargs = False + + for i, reward_func in enumerate(self.reward_funcs): + if isinstance(reward_func, PreTrainedModel): + if self.is_deepspeed_enabled: + self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator) + else: + self.reward_funcs[i] = self.accelerator.prepare_model( + reward_func, evaluation_mode=True, device_placement=True) + + # Multi-step + self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper + self.epsilon_low = args.epsilon + self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon + + # Tracks the number of iterations (forward + backward passes), including those within a gradient accumulation cycle. # noqa + self._step = 0 + # Buffer the batch to reuse generated outputs across multiple updates. For more details, see + # `_get_train_sampler` and `_prepare_inputs`. 
+ self._buffered_inputs = None + if self.args.async_generate: + self.add_callback(GRPOCallback(self)) + + if self.args.dynamic_sample: + self.resample_dataset = deepcopy(self.train_dataset) + + def cyclic_iter(iterable): + while True: + for x in iterable: + yield x + + self.resample_iterator = cyclic_iter(self.get_resample_dataloader()) + # flag indicating whether the evaluation has started + self.eval_flag = False + + @profiling_decorator + def _prepare_inputs( + self, accumulated_local_batch: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]: + mode = 'train' if self.model.training else 'eval' + if mode == 'train': + generate_every = self.args.gradient_accumulation_steps * self.num_iterations + if self._step % generate_every == 0 or self._buffered_inputs is None: + accumulated_local_batch = self._generate_and_score_completions(accumulated_local_batch) + self._buffered_inputs = accumulated_local_batch # < this is the change + inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] + self._step += 1 + else: + inputs = self._generate_and_score_completions(accumulated_local_batch) + return inputs + + def split_batches(self): + """Sync weights in batches + Only split LLM layers for now: + 1. N batches for layers + 2. other, embeds, lm_heads in one batch + 3. multi-modal components in one batch + """ + model = self.accelerator.unwrap_model(self.model) + if self.args.move_model_batches is None: + # All in one + return [[n for n, p in model.named_parameters() if 'ref_model' not in n]], [None] + + model_arch = get_model_arch(model.model_meta.model_arch) + non_llm_parameters = [] + llm_embeds = [] + parameters = [] + pattern = r'\.(\d+)\.' + + layer_count = None + # Get the number of layers in LLM modules + for name, module in model.named_modules(): + if isinstance(module, ModuleList): + if model_arch is not None and isinstance(model_arch, MultiModelKeys): + llm = model_arch.language_model + vision_tower = model_arch.vision_tower + if any(vt in name for vt in vision_tower): + continue + if isinstance(llm, list): + llm = llm[0] + if name.startswith('base_model'): + name = name.replace('base_model.', '') + if llm in name: + layer_count = len(module) + else: + layer_count = len(module) + assert layer_count is not None, 'Cannot find ModuleList to split modules.' 
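+        # Split the LLM layers into `move_model_batches` groups of roughly
+        # ceil(layer_count / move_model_batches) layers each; embedding/head parameters
+        # and non-LLM (e.g. vision tower) parameters are appended afterwards as their
+        # own groups, so weights can be gathered and synced to the inference engine in
+        # bounded-size batches.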
+ + n_layers = ceil(layer_count / self.args.move_model_batches) + for _ in range(self.args.move_model_batches): + parameters.append([]) + + def replace_lora(name): + if 'lora_' in name: + return '' + else: + return name.replace('base_layer.', '') + + def remove_lora_and_prefix(names): + names = set([re.sub(r'^_model\.', '', replace_lora(n)) for n in names]) + return [n for n in names if n] + + def split_llm(name): + match = re.search(pattern, name) + if match: + number = match.group(1) + group = int(number) // n_layers + parameters[group].append(name) + else: + llm_embeds.append(name) + + for name, parameter in model.named_parameters(): + if 'ref_model' in name: + continue + if model_arch is not None and isinstance(model_arch, MultiModelKeys): + llm = model_arch.language_model + vision_tower = model_arch.vision_tower + if any(vt in name for vt in vision_tower): + non_llm_parameters.append(name) + elif isinstance(llm, list): + llm = llm[0] + if llm in name: + split_llm(name) + else: + non_llm_parameters.append(name) + else: + split_llm(name) + + if llm_embeds: + parameters.append(llm_embeds) + if non_llm_parameters: + parameters.append(non_llm_parameters) + parameters = [p for p in parameters if p] + parameters_no_lora = [remove_lora_and_prefix(p_list) for p_list in parameters] + return parameters, parameters_no_lora + + def prepare_vllm(self, model, fast_infer_device): + from swift.tuners import Swift + from swift.llm import VllmEngine + from swift.llm.infer.infer_engine import GRPOVllmEngine + _, _, _, local_world_size = get_dist_setting() + if self.args.tensor_parallel_size > 1: + vllm_kwargs = {'distributed_executor_backend': 'external_launcher'} + else: + vllm_kwargs = {} + if local_world_size == self.args.num_infer_workers == get_device_count() and local_world_size > 1: + # Compatibility with TP + cls = GRPOVllmEngine + engine_kwargs = {'seed': 0} + else: + cls = VllmEngine + engine_kwargs = {} + with Swift.grpo_context(model, self.template.processor): + engine = cls( + model.model_dir, + model.model_info.torch_dtype, + model_type=model.model_meta.model_type, + device=fast_infer_device[self.local_infer_rank], + tensor_parallel_size=self.args.tensor_parallel_size, + gpu_memory_utilization=self.args.vllm_gpu_memory_utilization, + enable_prefix_caching=self.args.vllm_enable_prefix_caching, + max_num_seqs=self.args.vllm_max_num_seqs, + enforce_eager=self.args.vllm_enforce_eager, + limit_mm_per_prompt=self.args.vllm_limit_mm_per_prompt, + num_infer_workers=self.args.num_infer_workers, + enable_sleep_mode=self.args.sleep_level > 0, + use_async_engine=False, + max_model_len=self.args.vllm_max_model_len, + engine_kwargs=engine_kwargs, + **vllm_kwargs) + engine.default_template = self.template + return engine + + @property + def infer_rank(self): + if self.is_external_vllm: + # When using external vLLM, only the main process (rank=0) acts as the client. 
+ return 0 if self.accelerator.is_main_process else -1 + rank, local_rank, world_size, local_world_size = get_dist_setting() + node_rank = get_node_setting()[0] + for _vllm_rank in range(self.args.num_infer_workers): + if local_rank == _vllm_rank: + return node_rank * self.args.num_infer_workers + _vllm_rank + if local_rank == -1: + return 0 + return -1 + + @property + def infer_rank_tp_0(self): + # whether is tp rank0, get data from this rank + # vllm needs all tp ranks inputs and sampling params are the same + rank, local_rank, world_size, local_world_size = get_dist_setting() + node_rank = get_node_setting()[0] + for _vllm_rank in range(self.args.num_infer_workers): + if local_rank == _vllm_rank and _vllm_rank % self.args.tensor_parallel_size == 0: + return (node_rank * self.args.num_infer_workers + _vllm_rank // self.args.tensor_parallel_size) + if local_rank == -1: + return 0 + return -1 + + @property + def local_infer_rank(self): + rank, local_rank, world_size, local_world_size = get_dist_setting() + for _vllm_rank in range(self.args.num_infer_workers): + if local_rank == _vllm_rank: + return _vllm_rank + + return -1 + + def tp_group_ranks(self): + rank, local_rank, world_size, local_world_size = get_dist_setting() + return [ + list(range(0, world_size))[i:i + self.args.tensor_parallel_size] + for i in range(0, world_size, self.args.tensor_parallel_size) + ] + + @contextmanager + def _template_context(self, template): + # The max_length for prompt and completion has already been restricted, so there is no need for max_length here. + max_length = template.max_length + mode = template.mode + if mode in {'vllm', 'pt', 'lmdeploy'}: + template.set_mode('train') + template.max_length = None + loss_scale = template.loss_scale + if self.multi_turn_func: + template.loss_scale = 'default' + try: + yield + finally: + template.loss_scale = loss_scale + template.set_mode(mode) + template.max_length = max_length + + @profiling_decorator + def _move_model_to_vllm_lmdeploy(self): + if self.is_external_vllm: + return super()._move_model_to_vllm() + + from accelerate.utils.other import is_compiled_module + + for i, parameter_group in enumerate(self.parameter_groups): + parameter_group_no_lora = self.parameter_groups_no_lora[i] + with unwrap_model_for_generation( + self.model, + self.accelerator, + gather_deepspeed3_params=self.args.ds3_gather_for_generation, + gather_parameters=parameter_group) as unwrapped_model: + + if is_compiled_module(unwrapped_model): + unwrapped_model = unwrapped_model._orig_mod + if is_peft_model(unwrapped_model): + with patch_lora_merge(unwrapped_model, parameter_group): + unwrapped_model.merge_adapter() + state_dict = unwrapped_model.state_dict() + # Remove base_model and base_layer prefixes + state_dict = { + k.removeprefix('base_model.model.').replace('.base_layer', ''): v + for k, v in state_dict.items() + } + # Remove values with adapter prefix (example: "_lora") + state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k} + # When module to save, remove its prefix and discard the original module + state_dict = { + k.replace('modules_to_save.default.', ''): v + for k, v in state_dict.items() if 'original_module' not in k + } + else: + state_dict = unwrapped_model.state_dict() + if parameter_group_no_lora: + parameter_group_no_lora = [n.replace('base_model.model.', '') for n in parameter_group_no_lora] + state_dict = {k: v for k, v in state_dict.items() if k in parameter_group_no_lora} + assert len(state_dict) > 0 and all([state.shape != 
torch.Size([0]) for state in state_dict.values()]) + if self.infer_rank >= 0: + if self.args.async_generate: + self._wait_queue() + if self.args.use_vllm: + llm_model = self.engine.inner_model + else: + llm_model = self.engine.engine.engine + llm_model.load_weights(state_dict.items()) + del state_dict + gc_collect() + # Unmerge the adapter to restore the model to its original state. + # This must be done after loading weights to ensure they correspond to the merged state. + if is_peft_model(unwrapped_model): + with patch_lora_unmerge(unwrapped_model): + unwrapped_model.unmerge_adapter() + + if self.infer_rank >= 0 and self.args.use_vllm and self.args.vllm_enable_prefix_caching: + self.engine.engine.reset_prefix_cache() + + def _wait_queue(self): + while self._queue.empty(): + time.sleep(0.01) + + @staticmethod + def reorder_outputs(outputs, distributed_idx): + index_to_output = {} + current_position = 0 + for output_idx in distributed_idx: + for idx in output_idx: + index_to_output[idx] = outputs[current_position] + current_position += 1 + + return [index_to_output[idx] for idx in sorted(index_to_output.keys())] + + def _infer_multi_turn(self, inputs_slice: np.ndarray, request_config: RequestConfig) -> Union[OutputsType, List]: + """Perform multi-turn or single-turn inference with support for tensor parallelism. + + Args: + inputs_slice: Array of input requests + request_config: Inference configuration parameters + + Returns: + List of outputs where each entry contains: + - List of responses per prompt (length = tensor_parallel_size) + - Each response is a tuple of (message_history, finish_reason) + """ + from swift.llm.infer.protocol import ChatCompletionResponse + rank, _, _, _ = get_dist_setting() + request_config = copy(request_config) + results: List[ChatCompletionResponse] = self._engine_infer( + infer_requests=inputs_slice, request_config=request_config, use_tqdm=False) + prompt_lens = len(inputs_slice) + messages_list = [None] * (len(inputs_slice) * self.args.tensor_parallel_size) + if self.multi_turn_func: + remove_response = True + while len(inputs_slice) > 0: + request_config.n = 1 + if self.infer_rank_tp_0 >= 0 or not self.use_fast_infer: + inputs = [] + cnt = 0 + for i, output in enumerate(results): + for choice in output.choices: + _input: Dict = deepcopy(inputs_slice[i]) + if remove_response or _input['messages'][-1]['role'] != 'assistant' or not \ + _input['messages'][-1]['content']: + InferRequest.remove_response(_input['messages']) + _input['messages'].append({'role': 'assistant', 'content': choice.message.content}) + else: + _input['messages'][-1]['content'] += choice.message.content + if 'index' not in _input: + _input['index'] = cnt + _input['finish_reason'] = choice.finish_reason + cnt += 1 + inputs.append(_input) + results: List[Dict] = self.multi_turn_func(inputs) # noqa + else: + length = sum([len(results[i].choices) for i in range(len(results))]) + results = [None] * length + + if self.args.tensor_parallel_size > 1: + # avoid duplicate calling in the same tensor parallel group + import torch.distributed as dist + if 'group_src' in inspect.signature(dist.broadcast_object_list).parameters: + dist.broadcast_object_list(results, group_src=0, group=self.group) + else: + global_src = dist.get_global_rank(self.group, 0) + dist.broadcast_object_list(results, src=global_src, group=self.group) + inputs_slice = [r for r in results if not r['finished']] + for idx, r in enumerate(results): + if r['finished'] or r['finish_reason'] == 'length': + messages_list[r['index']] = 
(r['messages'], r['finish_reason']) + if len(inputs_slice) > 0: + _input_std = [] + for _input in inputs_slice: + _input_std.append(StdTemplateInputs.from_dict(_input)) + # StdTemplateInputs will not remove responses in infer + results = self._engine_infer( + infer_requests=_input_std, request_config=request_config, use_tqdm=False) + # concat responses from the second loop + remove_response = False + + outputs = [] + assert not any([m is None for m in messages_list]) + for i in range(0, len(messages_list), self.args.tensor_parallel_size): + # reformat to [[x, x, x, x] [x, x, x, x]] + # this is the same format of sampling_params.n > 1 + outputs.append(messages_list[i:i + self.args.tensor_parallel_size]) + assert len(outputs) == prompt_lens + assert all([len(o) == self.args.tensor_parallel_size for o in outputs]) + else: + # single turn + outputs = [] + for i, output in enumerate(results): + _choices = [] + for choice in output.choices: + _input: Dict = deepcopy(inputs_slice[i]) + InferRequest.remove_response(_input['messages']) + _input['messages'].append({'role': 'assistant', 'content': choice.message.content}) + _choices.append((_input['messages'], choice.finish_reason)) + outputs.append(_choices) + assert len(outputs) == prompt_lens + assert all([len(o) == self.args.tensor_parallel_size for o in outputs]) + + if self.args.tensor_parallel_size > 1: + if self.infer_rank_tp_0 < 0: + outputs = [] + else: + _outputs = [] + for tp_idx in range(self.args.tensor_parallel_size): + for prompt_idx in range(len(outputs)): + _outputs.append(outputs[prompt_idx][tp_idx]) + outputs = [_outputs] + + return outputs + + def async_infer(self, inputs, inputs_slice, distributed_idx): + + def infer_task(): + with set_device_context(self.infer_device), self.multi_turn_completion_length_context(): + return self._infer_multi_turn(inputs_slice, self.request_config) + + future: Future = self.executor.submit(infer_task) + # pre-fetch the queue to avoid switching back to eval_queue at the end of training sample sampling + current_queue = self._queue + + def done(_self): + current_queue.put(DataCache(inputs, _self.result(), distributed_idx)) + + future.add_done_callback(done) + + def _prefetch(self, dataloader: DataLoader): + inputs = next(iter(dataloader)) + all_inputs = gather_object(inputs) + nnodes = get_node_setting()[1] + distributed_idx = round_robin(len(all_inputs), nnodes * self.args.num_infer_workers) + if self.infer_rank >= 0: + _input_slice = np.array(all_inputs)[distributed_idx[self.infer_rank]] + with self.multi_turn_completion_length_context(): + outputs = self._infer_multi_turn(_input_slice, self.request_config) + self._queue.put(DataCache(inputs, outputs, distributed_idx)) + else: + self._queue.put(DataCache(inputs, [], distributed_idx)) + if self.accelerator.num_processes > 1: + self.accelerator.wait_for_everyone() + + def _fast_infer(self, inputs: InputsType) -> Tuple[InputsType, OutputsType]: + """ + This function performs fast inference by managing model and optimizer offloading, + loading weights if necessary, distributing inputs among workers, and generating + completions using the vLLM/LMDeploy framework. It supports both synchronous and asynchronous + inference modes. 
+ inputs: local inputs + """ + + if not self.is_external_vllm and self.args.sleep_level > 0 and self.infer_rank >= 0: + if self.args.offload_model: + self.offload_model() + if self.args.offload_optimizer: + self.offload_optimizer() + if self.args.gc_collect_after_offload: + gc_collect() + # Skip the first wake_up to avoid the warning "Executor is not sleeping" + if self.engine.inner_model_executor.is_sleeping: + self.engine.engine.wake_up() + # First, have main process load weights if needed + if self.state.global_step != self._last_loaded_step: + self._move_model_to_vllm_lmdeploy() + self._last_loaded_step = self.state.global_step + all_inputs = gather_object(inputs) + # Generate completions using vLLM: gather all prompts and use them in a single call in the main process + # Distribute inputs to different workers + # for example, 2 workers, 6 inputs, 0/2/4 dispatch to the first worker + # 1/3/5 dispatch to the second worker + # trying to shuffle and average the length + nnodes = get_node_setting()[1] + num_workers = 1 if self.is_external_vllm else nnodes + distributed_idx = round_robin(len(all_inputs), num_workers * self.args.num_infer_workers) + if self.infer_rank >= 0: + _input_slice = np.array(all_inputs)[distributed_idx[self.infer_rank]] + if self.args.async_generate: + self.async_infer(inputs, _input_slice, distributed_idx) + data_cache = self._queue.get() + inputs = data_cache.inputs + outputs = data_cache.outputs + distributed_idx = data_cache.distributed_idx + else: + with set_device_context(self.infer_device): + request_config = copy(self.request_config) + if self.args.tensor_parallel_size > 1: + request_config.seed += self.state.global_step + with self.multi_turn_completion_length_context(): + outputs = self._infer_multi_turn(_input_slice, self.request_config) + else: + if self.args.async_generate: + # using old model to generate, which will ignore the `clip` of advantages. + self._queue.put(DataCache(inputs, [], distributed_idx)) + data_cache = self._queue.get() + inputs = data_cache.inputs + distributed_idx = data_cache.distributed_idx + outputs = [] + outputs = gather_object(outputs) + if self.args.tensor_parallel_size > 1: + outputs = [[item] for output in outputs for item in output] + if not self.is_external_vllm: + outputs = self.reorder_outputs(outputs, distributed_idx) + if not self.is_external_vllm and self.args.sleep_level > 0 and self.infer_rank >= 0: + self.engine.engine.sleep(level=self.args.sleep_level) + if self.args.gc_collect_after_offload: + gc_collect() + if self.args.offload_model: + self.load_model() + if self.args.offload_optimizer: + self.load_optimizer() + return inputs, outputs + + def _generate_completions(self, inputs: InputsType) -> InputsType: + """Generate completions for given inputs using either fast inference or standard PyTorch inference. + + Args: + inputs: List of input examples containing conversation messages. + + Returns: + Modified inputs with generated completions added to the last message + and truncation flag set in 'is_truncated' field. 
+ """ + mode = 'train' if self.model.training else 'eval' + if self.use_fast_infer: + inputs, outputs = self._fast_infer(inputs) + # Slice to keep only the local part of the data + process_slice = slice( + self.accelerator.process_index * len(inputs), + (self.accelerator.process_index + 1) * len(inputs), + ) + outputs = outputs[process_slice] + else: + # pt infer + is_multimodal = self.model.model_meta.is_multimodal + if is_multimodal: + models = self.template.remove_post_encode_hook() + with unwrap_model_for_generation( + self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ), self.multi_turn_completion_length_context(): + outputs = self._infer_multi_turn(inputs, self.request_config) + if mode == 'train': + # In training mode, ensure the model is returned to train() mode after inference + # This is necessary as pt engines set the model to eval mode during generation + self.model.train() + if is_multimodal: + self.template.register_post_encode_hook(models) + if isinstance(outputs[0][0], list): + outputs = [output[0] for output in outputs] + + for i, output in enumerate(outputs): + inputs[i]['messages'] = output[0][0] + inputs[i]['is_truncated'] = output[0][1] == 'length' + + return inputs + + def _generate_and_score_completions(self, inputs: InputsType) -> InputsType: + + inputs = self._generate_completions(inputs) + total_rewards_per_func, total_rewards, completions = self._score_completions(inputs) + mode = 'train' if self.model.training else 'eval' + + if self.args.dynamic_sample and mode == 'train': + # dynamic sampling for std=0 groups + inputs, total_rewards, total_rewards_per_func, completions = \ + self._dynamic_sampling(inputs, total_rewards, total_rewards_per_func, completions) + + # Prepare final outputs with advantages and other required fields + batch_encoded_inputs = self._prepare_batch_inputs(inputs, total_rewards) + # Log metrics + messages = [inputs[i]['messages'][:-1] for i in range(len(inputs))] + + self._log_metrics(batch_encoded_inputs, messages, completions, total_rewards, total_rewards_per_func) + + return batch_encoded_inputs + + def _score_completions(self, inputs: InputsType) -> Tuple[torch.Tensor, torch.Tensor, List[str]]: + """Score completions using all reward functions + + Args: + inputs: List of input examples, each containing a 'messages' list with conversation history + + Returns: + Tuple containing: + - rewards_per_func: Tensor of shape (num_examples, num_reward_funcs) with individual rewards + - total_rewards: Tensor of shape (num_examples,) with weighted sum of rewards + - completions: List of generated completion strings + """ + device = self.accelerator.device + completions = [example['messages'][-1]['content'] for example in inputs] + rewards_per_func = torch.zeros((len(inputs), len(self.reward_funcs)), device=device) + + for i, (reward_func, reward_model_plugin) in enumerate(zip(self.reward_funcs, self.reward_model_plugins)): + # reward model + if isinstance(reward_func, nn.Module): + rewards_per_func[:, i] = reward_model_plugin(inputs=inputs) + # reward function + else: + # Repeat all input columns (but "messages" and "completion") to match the number of generations + reward_kwargs = RowPreprocessor.rows_to_batched(inputs) + output_reward_func = reward_func(completions, **reward_kwargs) + rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + + total_rewards_per_func = gather(rewards_per_func) + total_rewards = (total_rewards_per_func * 
self.reward_weights.to(device).unsqueeze(0)).sum(dim=1) + + return total_rewards_per_func, total_rewards, completions + + def _dynamic_sampling(self, inputs, rewards, rewards_per_func, completions): + # DAPO https://arxiv.org/abs/2503.14476 + # Replaces samples with zero-reward-variance groups (std=0) + resample_count = 0 + valid_samples = [] + valid_rewards = [] + valid_rewards_per_func = [] + valid_completions = [] + + origin_data = (inputs, rewards, rewards_per_func, completions) + + while resample_count < self.args.max_resample_times: + grouped_rewards = rewards.view(-1, self.num_generations) + group_std = grouped_rewards.std(dim=1) + + valid_mask = (group_std > 0).repeat_interleave(self.num_generations) + all_inputs = gather_object(inputs) + valid_samples.extend([inp for inp, mask in zip(all_inputs, valid_mask) if mask]) + valid_rewards.append(rewards[valid_mask]) + valid_rewards_per_func.append(rewards_per_func[valid_mask]) + valid_completions.extend( + [inp['messages'][-1]['content'] for inp, mask in zip(all_inputs, valid_mask) if mask]) + + if len(valid_samples) >= self.effective_train_batch_size: + break + + inputs = next(self.resample_iterator) + inputs = Trainer._prepare_inputs(self, inputs) + inputs = self._generate_completions(inputs) + rewards_per_func, rewards, completions = self._score_completions(inputs) + resample_count += 1 + + if len(valid_samples) >= self.effective_train_batch_size: + process_slice = slice( + self.accelerator.process_index * len(inputs), + (self.accelerator.process_index + 1) * len(inputs), + ) + inputs = valid_samples[:self.effective_train_batch_size][process_slice] + rewards = torch.cat(valid_rewards)[:self.effective_train_batch_size] + rewards_per_func = torch.cat(valid_rewards_per_func)[:self.effective_train_batch_size] + completions = valid_completions[:self.effective_train_batch_size][process_slice] + else: + logger.warning(f'There are still std=0 groups present after {self.args.max_resample_times} retries.') + inputs, rewards, rewards_per_func, completions = origin_data + + return inputs, rewards, rewards_per_func, completions + + def _prepare_batch_inputs(self, inputs: InputsType, rewards: torch.Tensor) -> List[InputsType]: + """ + Prepare the final batch inputs with advantages, ref/old_policy logps and other fields for RL training. + + Args: + inputs (InputsType): List of input samples. Original shape is [gas*bs] where: + - gas: gradient accumulation steps + - bs: per-device batch size + rewards (torch.Tensor): Tensor of rewards corresponding to the inputs. 
+ Shape should match the total number of samples (gas*bs*num_generations) + + Returns: + List[InputsType]: A list of prepared batch inputs, organized as [gas][bs] + """ + # Compute advantages + grouped_rewards = rewards.view(-1, self.num_generations) + mean_grouped_rewards = grouped_rewards.mean(dim=1).repeat_interleave(self.num_generations, dim=0) + std_grouped_rewards = grouped_rewards.std(dim=1).repeat_interleave(self.num_generations, dim=0) + advantages = (rewards - mean_grouped_rewards) + if self.args.scale_rewards: + advantages /= (std_grouped_rewards + 1e-4) + + # Slice to keep only the local part of the data + process_slice = slice( + self.accelerator.process_index * len(inputs), + (self.accelerator.process_index + 1) * len(inputs), + ) + advantages = advantages[process_slice] + + mode = 'train' if self.model.training else 'eval' + bs = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size + gas = self.args.gradient_accumulation_steps if mode == 'train' else 1 + + assert len(inputs) == bs * gas, f'Expected {bs * gas} inputs, got {len(inputs)}' + gas_chunks = [inputs[i * bs:(i + 1) * bs] for i in range(gas)] + + ga_batch_encoded_inputs = [] + template = self.template + + # Split advantages by GAS chunks + advantage_chunks = torch.chunk(advantages, gas) + + for i, (batch, batch_advantages) in enumerate(zip(gas_chunks, advantage_chunks)): + # Encode and process each batch (size=bs) + with self._template_context(template): + batch_encoded_inputs = [template.encode(infer_request) for infer_request in batch] + batch_encoded_inputs = to_device(template.data_collator(batch_encoded_inputs), self.model.device) + + # Process labels and masks + labels = batch_encoded_inputs.pop('labels') + logits_to_keep = (labels.shape[-1] - (torch.ne(labels, -100).int().argmax(-1))).max().item() + batch_encoded_inputs.update({ + 'completion_mask': + labels[:, -logits_to_keep:] != -100, + 'truncated_mask': + torch.tensor([b['is_truncated'] for b in batch], dtype=torch.bool), + 'logits_to_keep': + logits_to_keep, + 'advantages': + batch_advantages + }) + + with torch.no_grad(): + batch_encoded_inputs['old_per_token_logps'] = ( + self._get_per_token_logps(self.model, batch_encoded_inputs) if self.old_policy else None) + + if self.beta == 0.0: + ref_per_token_logps = None + elif self.ref_model is not None: + ref_per_token_logps = self._get_per_token_logps(self.ref_model, batch_encoded_inputs) + else: + with self.accelerator.unwrap_model(self.model).disable_adapter(): + ref_per_token_logps = self._get_per_token_logps(self.model, batch_encoded_inputs) + batch_encoded_inputs['ref_per_token_logps'] = ref_per_token_logps + + ga_batch_encoded_inputs.append(batch_encoded_inputs) + + return ga_batch_encoded_inputs + + def _log_metrics(self, inputs, messages, completions, rewards, rewards_per_func): + """Log training/evaluation metrics""" + mode = 'train' if self.model.training else 'eval' + device = self.accelerator.device + + # Calculate completion length metrics + agg_completion_mask = gather(torch.cat([inp['completion_mask'].sum(1) for inp in inputs])) + + self._metrics[mode]['completions/mean_length'].append(agg_completion_mask.float().mean().item()) + self._metrics[mode]['completions/min_length'].append(agg_completion_mask.float().min().item()) + self._metrics[mode]['completions/max_length'].append(agg_completion_mask.float().max().item()) + # Calculate clip ratio + agg_truncated_mask = gather(torch.cat([inp['truncated_mask'] for inp in inputs]).to(device)) + + 
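+        # 'truncated_mask' marks completions that were stopped by the length limit; the clipped ratio below is
+        # the fraction of all gathered completions that were truncated rather than finished naturally.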
term_completion_mask = agg_completion_mask[agg_truncated_mask] + clipped_completions_ratio = len(term_completion_mask) / len(agg_completion_mask) + + self._metrics[mode]['completions/clipped_ratio'].append(clipped_completions_ratio) + + for i, reward_func_name in enumerate(self.reward_func_names): + mean_rewards = rewards_per_func[:, i].mean().item() + self._metrics[mode][f'rewards/{reward_func_name}/mean'].append(mean_rewards) + std_rewards = rewards_per_func[:, i].std().item() + self._metrics[mode][f'rewards/{reward_func_name}/std'].append(std_rewards) + + # Log overall reward stats + grouped_rewards = rewards.view(-1, self.num_generations) + self._metrics[mode]['reward'].append(grouped_rewards.mean().item()) + self._metrics[mode]['reward_std'].append(grouped_rewards.std(dim=1).mean().item()) + + # Log prompt and completion texts + self._textual_logs['prompt'].extend(gather_object(messages)) + self._textual_logs['completion'].extend(gather_object(completions)) + for i, name in enumerate(self.reward_func_names): + self._textual_logs['rewards'][name].extend(rewards_per_func[:, i].tolist()) + + @profiling_decorator + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + # Compute the per-token log probabilities for the model, return_outputs=True in mini-batch training + if isinstance(inputs, list): + assert len(inputs) == 1 + inputs = inputs[0] + completion_mask = inputs['completion_mask'] + truncated_mask = inputs['truncated_mask'] + # apply the completion_mask to exclude loss and metrics for overlong completions + if self.args.overlong_filter and any(truncated_mask): + if all(truncated_mask): + logger.info('All completions are overlong, loss and KL will be zero') + truncated_mask = truncated_mask.unsqueeze(-1).expand_as(completion_mask).to(completion_mask.device) + completion_mask = completion_mask * (~truncated_mask) + + per_token_logps = self._get_per_token_logps(model, inputs) + + # Compute the KL divergence between the model and the reference model + if self.beta != 0.0: + ref_per_token_logps = inputs['ref_per_token_logps'] + per_token_kl = ( + torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1) + + advantages = inputs['advantages'] + old_per_token_logps = inputs['old_per_token_logps'] if self.old_policy else per_token_logps.detach() + coef_1 = torch.exp(per_token_logps - old_per_token_logps) + coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) + per_token_loss1 = coef_1 * advantages.unsqueeze(1) + per_token_loss2 = coef_2 * advantages.unsqueeze(1) + per_token_loss = -torch.min(per_token_loss1, per_token_loss2) + if self.beta != 0.0: + per_token_loss = per_token_loss + self.beta * per_token_kl + + if self.loss_type == 'grpo': + loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean() + elif self.loss_type == 'bnpo': + loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) + elif self.loss_type == 'dr_grpo': + loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) + else: + raise ValueError(f'Unknown loss type: {self.loss_type}') + + # Log the metrics + mode = 'train' if self.model.training else 'eval' + + if self.beta != 0.0: + mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum() + self._metrics[mode]['kl'].append(self.accelerator.gather_for_metrics(mean_kl).nanmean().item()) + + # Compute the clipped probability ratios + is_low_clipped = 
(coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0) + is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0) + is_region_clipped = is_low_clipped | is_high_clipped + + low_clip = (is_low_clipped * completion_mask).sum() / completion_mask.sum() + high_clip = (is_high_clipped * completion_mask).sum() / completion_mask.sum() + clip_ratio = (is_region_clipped * completion_mask).sum() / completion_mask.sum() + + gathered_low_clip = self.accelerator.gather_for_metrics(low_clip) + self._metrics[mode]['clip_ratio/low_mean'].append(gathered_low_clip.nanmean().item()) + self._metrics[mode]['clip_ratio/low_min'].append(nanmin(gathered_low_clip).item()) + gathered_high_clip = self.accelerator.gather_for_metrics(high_clip) + self._metrics[mode]['clip_ratio/high_mean'].append(gathered_high_clip.nanmean().item()) + self._metrics[mode]['clip_ratio/high_max'].append(nanmax(gathered_high_clip).item()) + gathered_clip_ratio = self.accelerator.gather_for_metrics(clip_ratio) + self._metrics[mode]['clip_ratio/region_mean'].append(gathered_clip_ratio.nanmean().item()) + + return loss + + # Get the per-token log probabilities for the completions for the model and the reference model + @profiling_decorator + def _get_per_token_logps(self, model, inputs): + from trl.trainer.utils import selective_log_softmax + logits_to_keep = inputs['logits_to_keep'] + input_ids = inputs['input_ids'] + unwrapped_model = self.accelerator.unwrap_model(model) + if is_peft_model(unwrapped_model): + parameters = inspect.signature(unwrapped_model.base_model.model.forward).parameters + else: + parameters = inspect.signature(unwrapped_model.forward).parameters + if not unwrapped_model.model_meta.is_multimodal and 'logits_to_keep' in parameters: + # save memory + return super()._get_per_token_logps(model, input_ids, inputs['attention_mask'], logits_to_keep) + inputs = { + k: v + for k, v in inputs.items() if k not in [ + 'logits_to_keep', 'completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps', + 'truncated_mask' + ] + } + with self._template_context(self.template): + logits = model(**inputs).logits + # exclude the last logit: it corresponds to the next token pred + logits = logits[:, -(logits_to_keep + 1):-1, :] + logits = logits / self.temperature + input_ids = input_ids[:, -logits_to_keep:] + return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens + + def evaluation_loop(self, dataloader, *args, **kwargs): + # Wait for the training rollout to complete + if self.args.async_generate: + while not self.is_async_generate_eval_rollout_done(): + time.sleep(0.1) + if self._queue.empty() and self.args.async_generate: + self._prefetch(dataloader) + metric_key_prefix = kwargs['metric_key_prefix'] + output = super().evaluation_loop(dataloader, *args, **kwargs) + metrics = {f'{metric_key_prefix}_{key}': sum(val) / len(val) for key, val in self._metrics['eval'].items()} + output.metrics.update(metrics) + self.eval_flag = True + return output + + def training_step(self, model: nn.Module, inputs: InputsType, num_items_in_batch=None) -> torch.Tensor: + if self.args.async_generate: + # Wait for the eval rollout to complete + while not self.is_async_generate_eval_rollout_done(): + time.sleep(0.1) + return super().training_step(model, inputs, num_items_in_batch) + + def _engine_infer( + self, + infer_requests: List[InferRequest], + request_config: Optional[RequestConfig] = None, + *, + use_tqdm: Optional[bool] = None, + ): + if self.is_external_vllm: + 
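+            # For the external server, image bytes are converted to base64 strings before the request is sent
+            # (see _process_infer_requests_images below).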
self._process_infer_requests_images(infer_requests) + return self.vllm_client.infer(infer_requests.tolist(), asdict(request_config), use_tqdm=use_tqdm) + else: + return self.engine.infer(infer_requests, request_config, use_tqdm=use_tqdm) + + def _process_infer_requests_images(self, infer_requests: List[InferRequest]): + import base64 + if not any('images' in request for request in infer_requests): + return + for request in infer_requests: + if 'images' not in request: + continue + for i, img in enumerate(request['images']): + if 'bytes' in img and img['bytes']: + request['images'][i] = base64.b64encode(img['bytes']).decode('utf-8') + return + + @property + def old_policy(self): + return self.num_iterations > 1 + + @property + def _queue(self): + if self.control.should_evaluate: + return self.eval_queue + else: + return self.train_queue + + @torch.no_grad() + def offload_model(self): + if len(self.offload_modules) > 0: + return + unwrapped_model = self.accelerator.unwrap_model(self.model) + for name, module in unwrapped_model.named_modules(): + if isinstance(module, torch.nn.Embedding): + self.offload_modules[name] = module.weight.device + module.to('cpu') + elif not hasattr(module, 'device'): + pass + elif module.device.type != 'cpu': + self.offload_modules[name] = module.device + module.to('cpu') + + @torch.no_grad() + def load_model(self): + if len(self.offload_modules) == 0: + return + unwrapped_model = self.accelerator.unwrap_model(self.model) + for name, device in self.offload_modules.items(): + module = unwrapped_model.get_submodule(name) + if isinstance(module, torch.nn.Embedding): + module.weight.to(device) + else: + module.to(device) + self.offload_modules.clear() + + @torch.no_grad() + def offload_optimizer(self): + if len(self.offload_states) > 0: + return + if not self.optimizer.state: + return + for param_group in self.optimizer.param_groups: + for param in param_group['params']: + state = self.optimizer.state[param] + for key, value in state.items(): + if isinstance(value, torch.Tensor): + self.offload_states[key] = value.device + state[key] = value.to('cpu', non_blocking=True) + + @torch.no_grad() + def load_optimizer(self): + if len(self.offload_states) == 0: + return + if not self.optimizer.state: + return + for param_group in self.optimizer.param_groups: + for param in param_group['params']: + state = self.optimizer.state[param] + for key, value in state.items(): + if isinstance(value, torch.Tensor): + state[key] = value.to(self.offload_states[key], non_blocking=True) + self.offload_states.clear() + + @contextmanager + def multi_turn_completion_length_context(self): + """ + Context manager that temporarily adjusts the engine's max length handling + for multi-turn generation scenarios. 
+ + Ensures the total sequence length (prompt + completion) never exceeds: + min(original_max_len, prompt_tokens + max_completion_length) + """ + if not (self.multi_turn_func and self.infer_rank >= 0) or self.is_external_vllm: + yield + return + + original_fn = self.engine.set_default_max_tokens + original_max_len = self.engine.max_model_len + + def set_default_max_tokens(_self, request_config: RequestConfig, inputs: InputsType) -> None: + # Calculate required context window + original_max_len = _self.max_model_len or 8192 + if isinstance(inputs, dict): + inputs = [inputs] + prompt_tokens = max(_self._get_num_tokens(inp) for inp in inputs) + + if not hasattr(_self, 'set_grpo_max_model_len'): + # set max model len in first round + max_len = min(original_max_len, prompt_tokens + request_config.max_tokens) + _self.max_model_len = max_len + _self.set_grpo_max_model_len = True + else: + if _self.max_model_len <= prompt_tokens: + # modify max_model_len > prompt_tokens to avoid crash + num_tokens_avoid_crash = 10 + _self.max_model_len = (prompt_tokens + num_tokens_avoid_crash) + request_config.max_tokens = num_tokens_avoid_crash + + original_fn(request_config, inputs) + + try: + self.engine.set_default_max_tokens = MethodType(set_default_max_tokens, self.engine) + yield + finally: + self.engine.set_default_max_tokens = original_fn + self.engine.max_model_len = original_max_len + del self.engine.set_grpo_max_model_len + + def get_resample_dataloader(self) -> DataLoader: + resample_dataset = self.resample_dataset + data_collator = self.data_collator + if isinstance(resample_dataset, datasets.Dataset): + resample_dataset = self._remove_unused_columns(resample_dataset, description='training') + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description='training') + + dataloader_params = { + 'batch_size': self._train_batch_size * self.args.gradient_accumulation_steps, + 'collate_fn': data_collator, + 'num_workers': self.args.dataloader_num_workers, + 'pin_memory': self.args.dataloader_pin_memory, + 'persistent_workers': self.args.dataloader_persistent_workers, + } + + @contextmanager + def seed_context(self): + seed = self.args.seed + self.args.seed = seed + 1 + yield + self.args.seed = seed + + if not isinstance(resample_dataset, torch.utils.data.IterableDataset): + with seed_context(self): # Set a different seed for resampling than the train_dataset. + dataloader_params['sampler'] = self._get_train_sampler() + dataloader_params['drop_last'] = self.args.dataloader_drop_last + dataloader_params['worker_init_fn'] = seed_worker + dataloader_params['prefetch_factor'] = self.args.dataloader_prefetch_factor + + return self.accelerator.prepare(DataLoader(resample_dataset, **dataloader_params)) + + def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: + mode = 'train' if self.model.training else 'eval' + metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics + + # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` + # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. 
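+        # e.g. 'completions/mean_length' is reported as 'eval_completions/mean_length' during evaluation.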
+ if mode == 'eval': + metrics = {f'eval_{key}': val for key, val in metrics.items()} + + logs = {**logs, **metrics} + if version.parse(transformers.__version__) >= version.parse('4.47.0.dev0'): + super().log(logs, start_time) + else: # transformers<=4.46 + super().log(logs) + self._metrics[mode].clear() + + if self.accelerator.is_main_process and self.log_completions: + table = { + 'step': [str(self.state.global_step)] * len(self._textual_logs['prompt']), + 'prompt': self._textual_logs['prompt'], + 'completion': self._textual_logs['completion'], + **self._textual_logs['rewards'], + } + self.jsonl_writer.append(table) + if self.args.report_to and 'wandb' in self.args.report_to and wandb.run is not None: + import pandas as pd + df = pd.DataFrame(table) + if self.wandb_log_unique_prompts: + df = df.drop_duplicates(subset=['prompt']) + wandb.log({'completions': wandb.Table(dataframe=df)}) + + def is_async_generate_eval_rollout_done(self): + return not self.eval_flag or not self.eval_queue.empty() + + def is_async_generate_train_rollout_done(self): + return not self.train_queue.empty() diff --git a/swift/trainers/rlhf_trainer/__init__.py b/swift/trainers/rlhf_trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3b6d6a7fa3c254acb5ab1ae855de18b0c70ceaaa --- /dev/null +++ b/swift/trainers/rlhf_trainer/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from swift.utils.import_utils import _LazyModule + +if TYPE_CHECKING: + from .cpo_trainer import CPOTrainer + from .dpo_trainer import DPOTrainer + from .grpo_trainer import GRPOTrainer + from .kto_trainer import KTOTrainer + from .orpo_trainer import ORPOTrainer + from .ppo_trainer import PPOTrainer + from .reward_trainer import RewardTrainer + from .rlhf_mixin import RLHFTrainerMixin + from .utils import _split_into_mini_batches, patch_lora_merge, patch_lora_unmerge, round_robin +else: + _import_structure = { + 'cpo_trainer': ['CPOTrainer'], + 'dpo_trainer': ['DPOTrainer'], + 'grpo_trainer': ['GRPOTrainer'], + 'kto_trainer': ['KTOTrainer'], + 'orpo_trainer': ['ORPOTrainer'], + 'ppo_trainer': ['PPOTrainer'], + 'reward_trainer': ['RewardTrainer'], + 'rlhf_mixin': ['RLHFTrainerMixin'], + 'utils': ['_split_into_mini_batches', 'patch_lora_merge', 'patch_lora_unmerge', 'round_robin'], + } + + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/swift/trainers/rlhf_trainer/__pycache__/__init__.cpython-310.pyc b/swift/trainers/rlhf_trainer/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f439d8e86effeccdcfbb6fc7394c821717897299 Binary files /dev/null and b/swift/trainers/rlhf_trainer/__pycache__/__init__.cpython-310.pyc differ diff --git a/swift/trainers/rlhf_trainer/__pycache__/grpo_trainer.cpython-310.pyc b/swift/trainers/rlhf_trainer/__pycache__/grpo_trainer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98756756ff836dd8d503e2c66133021f6781c8dc Binary files /dev/null and b/swift/trainers/rlhf_trainer/__pycache__/grpo_trainer.cpython-310.pyc differ diff --git a/swift/trainers/rlhf_trainer/__pycache__/rlhf_mixin.cpython-310.pyc b/swift/trainers/rlhf_trainer/__pycache__/rlhf_mixin.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfbe549a20f6ff0295614b390f739ac18684b7b9 Binary files /dev/null and 
b/swift/trainers/rlhf_trainer/__pycache__/rlhf_mixin.cpython-310.pyc differ diff --git a/swift/trainers/rlhf_trainer/__pycache__/utils.cpython-310.pyc b/swift/trainers/rlhf_trainer/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3dbbb22c6024deaeddfb84225b8285e5226e095f Binary files /dev/null and b/swift/trainers/rlhf_trainer/__pycache__/utils.cpython-310.pyc differ diff --git a/swift/trainers/rlhf_trainer/cpo_trainer.py b/swift/trainers/rlhf_trainer/cpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..25e4c93578d7d732e581ddfac46420bf5ffe6548 --- /dev/null +++ b/swift/trainers/rlhf_trainer/cpo_trainer.py @@ -0,0 +1,32 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import warnings +from typing import Optional, Union + +import torch.nn as nn +from transformers import PreTrainedModel +from trl import CPOTrainer as HFCPOTrainer + +from ..mixin import SwiftMixin +from .rlhf_mixin import RLHFTrainerMixin + +del HFCPOTrainer.__init__ + + +class CPOTrainer(RLHFTrainerMixin, SwiftMixin, HFCPOTrainer): + + def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, *_args, **kwargs): + ref_model = kwargs.get('ref_model') + assert ref_model is None, 'CPO/SimPO does not require a ref_model.' + + args = kwargs['args'] + self.label_smoothing = args.label_smoothing + self.loss_type = args.loss_type + self.cpo_alpha = args.cpo_alpha + if args.loss_type == 'simpo': + self.simpo_gamma = args.simpo_gamma + if self.cpo_alpha > 0: + warnings.warn('You are using CPO-SimPO method because you set a non-zero cpo_alpha. ' + 'This will result in the CPO-SimPO method ' + '(https://github.com/fe1ixxu/CPO_SIMPO/tree/main). ' + 'If you want to use a pure SimPO method, please set cpo_alpha to 0.') + super().__init__(model, *_args, **kwargs) diff --git a/swift/trainers/rlhf_trainer/dpo_trainer.py b/swift/trainers/rlhf_trainer/dpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..2f03af82120fe16d29424383b3c68765d8e90355 --- /dev/null +++ b/swift/trainers/rlhf_trainer/dpo_trainer.py @@ -0,0 +1,129 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from peft import PeftModel +from transformers import PreTrainedModel +from trl import DPOTrainer as HFDPOTrainer + +from ..mixin import DataLoaderMixin, SwiftMixin +from .rlhf_mixin import RLHFTrainerMixin + +del HFDPOTrainer.__init__ + + +class DPOTrainer(RLHFTrainerMixin, SwiftMixin, DataLoaderMixin, HFDPOTrainer): + + def __init__(self, + model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + *_args, + **kwargs): + from trl.trainer import FDivergenceConstants + args = kwargs['args'] + self.label_smoothing = args.label_smoothing + self.loss_type = args.loss_type + self.precompute_ref_log_probs = args.precompute_ref_log_probs + self.f_divergence_type = args.f_divergence_type + self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef} + self.is_peft_model = isinstance(model, PeftModel) + + self.ref_adapter_name = args.ref_adapter_name + self.reference_free = args.reference_free + self.use_weighting = False + + super().__init__(model, ref_model, *_args, **kwargs) + + def get_nll_loss(self, logits, labels): + if not self.is_encoder_decoder: + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss(ignore_index=self.label_pad_token_id) + logits = logits.view(-1, logits.shape[-1]) + labels = labels.view(-1) + # Enable model parallelism + labels = labels.to(logits.device) + return loss_fct(logits, labels) + + def concatenated_forward( + self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + batch = batch.copy() + num_examples = batch['labels'].shape[0] // 2 + labels = batch.pop('labels', None) + if self.is_encoder_decoder: + batch['labels'] = labels + + if self.aux_loss_enabled: + batch['output_router_logits'] = True + outputs = model(**batch, use_cache=False) + batch['labels'] = labels + if outputs.logits.shape[1] != labels.shape[1]: + # for llava, the model returns logits for the entire sequence, including the image tokens + # (placed before the text tokens) + outputs.logits = outputs.logits[:, -labels.shape[1]:] + for key in ['input_ids', 'attention_mask', 'labels']: + batch[f'concatenated_{key}'] = batch.pop(key, None) + if self.__class__.__name__ == 'ORPOTrainer': # Pass-through labels + batch['concatenated_input_ids'] = batch['concatenated_labels'] + + all_logits = outputs.logits + + if all_logits.shape[:2] != batch['concatenated_labels'].shape[:2]: + # for llava, the model returns logits for the entire sequence, + # including the image tokens (placed before the text tokens) + seq_len = batch['concatenated_labels'].shape[1] + all_logits = all_logits[:, -seq_len:] + + all_logps, size_completion = self.get_batch_logps( + all_logits, + batch['concatenated_labels'], + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) + + output = {} + + if self.args.rpo_alpha is not None: + labels = batch['concatenated_labels'].clone() + output['nll_loss'] = self.get_nll_loss(all_logits[:num_examples], labels[:num_examples]) + + if self.loss_type == 'ipo': + all_logps = all_logps / size_completion + + output['chosen_logps'] = all_logps[:num_examples] + output['rejected_logps'] = 
all_logps[num_examples:] + output['mean_chosen_logits'] = all_logits[:num_examples].mean() + output['mean_rejected_logits'] = all_logits[num_examples:].mean() + + if self.aux_loss_enabled: + output['aux_loss'] = outputs.aux_loss + + return output + + @staticmethod + def get_batch_logps( + logits: torch.FloatTensor, + labels: torch.LongTensor, + label_pad_token_id: int = -100, + is_encoder_decoder: bool = False, + ) -> Tuple[torch.FloatTensor, torch.LongTensor]: + if logits.shape[:-1] != labels.shape: + raise ValueError(f'Logits (batch and sequence length dim) {logits.shape[:-1]}' + 'and labels must have the same shape {labels.shape}') + if not is_encoder_decoder: + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + else: + labels = labels.clone() + + loss_mask = labels != label_pad_token_id + + labels[labels == label_pad_token_id] = 0 + + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + + return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..a18db0f13fa9984c4b8ae4708f5a7f0a8321a063 --- /dev/null +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -0,0 +1,1426 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Part of the implementation is borrowed from huggingface/trl. +import concurrent.futures +import inspect +import os +import re +import time +from collections import defaultdict, deque +from concurrent.futures import Future +from contextlib import contextmanager +from copy import copy, deepcopy +from dataclasses import asdict, dataclass, field +from math import ceil +from queue import Queue +from types import MethodType +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import datasets +import numpy as np +import torch +import torch.nn as nn +import transformers +from accelerate.utils import gather, gather_object, is_peft_model, set_seed +from packaging import version +from torch.nn import ModuleList +from torch.utils.data import DataLoader +from transformers import PreTrainedModel, TrainerCallback +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.trainer import Trainer +from transformers.trainer_utils import seed_worker +from trl import GRPOTrainer as HFGRPOTrainer +from trl.extras.profiling import profiling_decorator +from trl.models import prepare_deepspeed +from trl.trainer.grpo_trainer import nanmax, nanmin + +from swift.llm import InferRequest, MultiModelKeys, RequestConfig, RowPreprocessor, get_model_arch, to_device +from swift.llm.infer.infer_engine import set_device_context +from swift.llm.template.template_inputs import StdTemplateInputs +from swift.plugin import multi_turns, orms, rm_plugins +from swift.utils import (JsonlWriter, gc_collect, get_device, get_device_count, get_dist_setting, get_logger, + get_node_setting, is_lmdeploy_available, is_vllm_available, is_wandb_available) +from ..mixin import SwiftMixin +from .rlhf_mixin import RLHFTrainerMixin +from .utils import patch_lora_merge, patch_lora_unmerge, round_robin + +del HFGRPOTrainer.__init__ +del HFGRPOTrainer.log + +logger = get_logger() +if is_wandb_available(): + import wandb + os.environ["WANDB_API_KEY"] = "a7ab128385681b17ad156ad0d8c81ba3e2296164" + os.environ["WANDB_MODE"] = "offline" + +InputsType = List[Dict[str, Union[torch.Tensor, Any]]] +OutputsType = List[List[Tuple[List[Dict], str]]] + + 
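+# Informal shape conventions for the aliases above: InputsType holds one request dict per local prompt, each
+# carrying a 'messages' conversation list; OutputsType nests (messages_with_completion, finish_reason) tuples
+# per prompt -- see _infer_multi_turn for how the inner lists are populated.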
+@contextmanager +def unwrap_model_for_generation( + model, + accelerator, + gather_deepspeed3_params=True, + gather_parameters: List = None, +): + unwrapped_model = accelerator.unwrap_model(model) + if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3: + if not gather_deepspeed3_params: + yield accelerator.unwrap_model(model) + else: + import deepspeed + parameters = [ + parameter for name, parameter in model.named_parameters() + if not gather_parameters or name in gather_parameters + ] + with deepspeed.zero.GatheredParameters(parameters): + from trl.models.utils import remove_hooks + remove_hooks(model) + yield accelerator.unwrap_model(model) + from trl.models.utils import add_hooks + add_hooks(model) + else: + yield unwrapped_model + + +class GRPOCallback(TrainerCallback): + + def __init__(self, trainer): + self.trainer = trainer + + # offload original_modules to cpu, to save memory + def on_train_begin(self, args, state, control, **kwargs): + self.trainer.queue = self.trainer.train_queue + train_dataloader = getattr(state, 'train_dataloader', None) or kwargs.get('train_dataloader') + self.trainer._prefetch(train_dataloader) + + +@dataclass +class DataCache: + inputs: List[Dict] = field(default_factory=list) + outputs: List[Dict] = field(default_factory=list) + distributed_idx: List[List] = field(default_factory=list) + + +class GRPOTrainer(RLHFTrainerMixin, SwiftMixin, HFGRPOTrainer): + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + + def __init__(self, + model: Optional[Union[PreTrainedModel, nn.Module]] = None, + ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None, + reward_model: Optional[List[Union[PreTrainedModel, nn.Module]]] = None, + reward_funcs: Optional[List[Union[str, Callable]]] = None, + *_args, + **kwargs): + from swift.trainers.rlhf_arguments import GRPOConfig + args: GRPOConfig = kwargs['args'] + self.args = args + self.train_queue = Queue() + self.eval_queue = Queue() + self.processing_class = kwargs.get('template').tokenizer + self.offload_modules = {} + self.offload_states = {} + _, _, _, local_world_size = get_dist_setting() + + if not isinstance(reward_funcs, list): + reward_funcs = [reward_funcs] + + if reward_funcs: + for i, reward_func in enumerate(reward_funcs): + if reward_func in orms: + reward_func_class = orms[reward_func] + reward_func_args = list(inspect.signature(reward_func_class.__init__).parameters) + reward_func_kwargs = { + key: getattr(args, key) + for key in reward_func_args if key not in ['self', 'args', 'kwargs'] and hasattr(args, key) + } + if 'tokenizer' in reward_func_args: + reward_func_kwargs['tokenizer'] = self.processing_class + reward_funcs[i] = reward_func_class(**reward_func_kwargs) + elif not callable(reward_func): + raise ValueError(f'reward_function {reward_func} is not implemented in swift.llm.plugin') + + self.reward_funcs = reward_funcs + self.reward_func_names = [] + for reward_func in reward_funcs: + if inspect.isfunction(reward_func): + reward_func_name = reward_func.__name__ + else: + reward_func_name = reward_func.__class__.__name__ + self.reward_func_names.append(reward_func_name) + + self.reward_model_plugins = [None] * len(self.reward_funcs) + + if reward_model is not None: + reward_template = kwargs.pop('reward_template') + reward_plugins = args.reward_model_plugin + if reward_plugins is None: + reward_plugins = ['default'] * len(reward_model) + assert len(reward_plugins) == len(reward_model), ( + f"The number of 'reward_model_plugin' 
({len(reward_plugins)}) does not match " + f"the number of 'reward_model' ({len(reward_model)}). " + "Please provide a corresponding 'reward_model_plugin' for each 'reward_model'.") + for rm, rm_plugin, rm_template in zip(reward_model, reward_plugins, reward_template): + # Set encoding mode train(see details in Template.encode). + # Set max_length to None to disable truncation, as the input length has already been truncated earlier. + rm_template.set_mode('train') + rm_template.max_length = None + if rm_plugin not in rm_plugins: + raise ValueError(f'rm_plugin {rm_plugin} is not implemented in swift.llm.plugin') + self.reward_model_plugins.append(rm_plugins[rm_plugin](model=rm, template=rm_template)) + self.reward_funcs.append(rm) + self.reward_func_names.append(rm.config._name_or_path.split('/')[-1]) + + if not self.reward_funcs: + raise ValueError('You must specify reward_funcs or reward_model') + + # Reward weights + if args.reward_weights is not None: + if len(args.reward_weights) != len(reward_funcs): + raise ValueError(f'Number of reward weights ({len(args.reward_weights)}) must match number of reward ' + f'functions ({len(reward_funcs)})') + self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) + else: + self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32) + + self.multi_turn_func = None + if self.args.multi_turn_func: + if isinstance(self.args.multi_turn_func, str): + assert self.args.multi_turn_func in multi_turns + multi_turn_func = multi_turns[self.args.multi_turn_func] + self.multi_turn_func = multi_turn_func + else: + self.multi_turn_func = self.args.multi_turn_func + + self.num_generations = args.num_generations + self.temperature = args.temperature + self.loss_type = args.loss_type + model.warnings_issued['estimate_tokens'] = True + kwargs['data_collator'] = lambda features: features + self.shuffle_dataset = args.dataset_shuffle + + use_vllm = args.use_vllm + use_lmdeploy = args.use_lmdeploy + vllm_client = kwargs.pop('vllm_client') # for external vllm + if self.args.tensor_parallel_size > 1 and self.multi_turn_func: + import torch.distributed as dist + rank, _, _, _ = get_dist_setting() + for tp_group in self.tp_group_ranks(): + group = dist.new_group(tp_group) + if rank in tp_group: + self.group = group + + super().__init__(model, ref_model, *_args, **kwargs) + + self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)} + self.log_completions = args.log_completions + self.wandb_log_unique_prompts = args.wandb_log_unique_prompts + self.num_completions_to_print = args.num_completions_to_print + self.jsonl_writer = JsonlWriter(os.path.join(self.args.output_dir, 'completions.jsonl')) + # maxlen is set to the total number of forward passes per step. This value of `maxlen` ensures we log only the + # final optimization step. 
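+        # For example (hypothetical values): 2 processes x per_device_train_batch_size 4 x
+        # gradient_accumulation_steps 8 keeps the last 2 * 4 * 8 = 64 prompt/completion/reward entries for logging.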
+        maxlen = self.accelerator.num_processes * args.per_device_train_batch_size * args.gradient_accumulation_steps
+        self._textual_logs = {
+            'prompt': deque(maxlen=maxlen),
+            'completion': deque(maxlen=maxlen),
+            'rewards': defaultdict(lambda: deque(maxlen=maxlen)),
+        }
+
+        num_processes = self.accelerator.num_processes
+        self.effective_train_batch_size = effective_batch_size = \
+            args.per_device_train_batch_size * num_processes * args.gradient_accumulation_steps
+        possible_values = [n_gen for n_gen in range(2, effective_batch_size + 1) if (effective_batch_size) % n_gen == 0]
+
+        if self.num_generations not in possible_values:
+            raise ValueError(
+                f'The effective train batch size ({num_processes} x {args.per_device_train_batch_size} x '
+                f'{args.gradient_accumulation_steps}) must be evenly divisible by the number of generations per '
+                f'prompt ({self.num_generations}). Given the current effective train batch size, the valid values for '
+                f'the number of generations are: {possible_values}.')
+        if self.args.eval_strategy != 'no':
+            effective_batch_size = args.per_device_eval_batch_size * num_processes
+            possible_values = [
+                n_gen for n_gen in range(2, effective_batch_size + 1) if (effective_batch_size) % n_gen == 0
+            ]
+            if self.num_generations not in possible_values:
+                raise ValueError(
+                    f'The effective eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be '
+                    f'evenly divisible by the number of generations per prompt ({self.num_generations}). Given the '
+                    'current effective eval batch size, the valid values for the number of generations are: '
+                    f'{possible_values}.')
+
+        # Ensure each process receives a unique seed to prevent duplicate completions when generating with
+        # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
+        # it's safer to set it in all cases.
+        set_seed(args.seed, device_specific=True)
+        self.parameter_groups, self.parameter_groups_no_lora = self.split_batches()
+        self.infer_device = None
+        self.use_fast_infer = use_vllm or use_lmdeploy  # whether to use a fast inference backend (vLLM/LMDeploy) instead of the PT backend
+        self.is_external_vllm = use_vllm and args.vllm_server_host is not None
+        if self.use_fast_infer:
+            if self.infer_rank >= 0:
+                fast_infer_device = self.args.vllm_device or self.args.lmdeploy_device
+                if fast_infer_device[0] == 'auto':
+                    if get_device_count() == 1:
+                        fast_infer_device = [get_device()]  # particular case when training with only 1 GPU: share it
+                    else:
+                        fast_infer_device = []
+                        for idx in range(get_device_count() - self.args.num_infer_workers, get_device_count()):
+                            fast_infer_device.append(get_device(idx))
+
+                for _device in fast_infer_device:
+                    # Check that the requested device is available
+                    if _device.split(':')[0] in {'cuda', 'npu'} and int(_device.split(':')[1]) >= get_device_count():
+                        raise ValueError(f'The requested device for vllm ({_device}) is not available. '
+                                         f'You are likely using vLLM '
+                                         'without restricting the number of GPUs for training. '
+                                         'Set the `--num_processes` argument to a '
+                                         'value lower than the number of GPUs available on your machine—typically, '
+                                         'reducing it by one is sufficient. '
+                                         f'In your case: `--num_processes {get_device_count() - 1}`.')
+
+                if use_vllm:
+                    if not is_vllm_available():
+                        raise ImportError('vLLM is not available and `use_vllm` is set to True. 
' + 'Please install vLLM with `pip install vllm -U` to use it.') + if self.is_external_vllm: + self.vllm_client = vllm_client + else: + self.engine = self.prepare_vllm(model, fast_infer_device) + self.infer_device = fast_infer_device[self.local_infer_rank] + elif use_lmdeploy: + if not is_lmdeploy_available(): + raise ImportError('LMDeploy is not available and `use_lmdeploy` is set to True.' + 'Please install LMDeploy with `pip install lmdeploy -U` to use it.') + from swift.llm import LmdeployEngine + from swift.tuners import Swift + with Swift.grpo_context(model, self.template.processor): + fast_infer_device = int(fast_infer_device[self.local_infer_rank].split(':')[1]) + self.engine = LmdeployEngine( + model.model_dir, + model.model_info.torch_dtype, + model_type=model.model_meta.model_type, + devices=[fast_infer_device], + session_len=args.lmdeploy_session_len, + cache_max_entry_count=args.lmdeploy_cache_max_entry_count, + reload_weights=True) + self.infer_device = fast_infer_device + from lmdeploy.turbomind.turbomind import TurboMind + lmdeploy_engine = self.engine.engine.engine + assert isinstance(lmdeploy_engine, TurboMind), ( + "Currently only LMDeploy's TurboMind backend is supported. " + 'The current model is incompatible - please use vLLM or PyTorch backend instead.') + if not self.is_external_vllm: + self.engine.default_template = copy(self.template) # Avoid thread-unsafe modifications of the mode. + self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation + + # When using vLLM, the main process is responsible for loading the model weights. This can cause process + # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we + # synchronize all processes after vLLM has been fully initialized. + self.accelerator.wait_for_everyone() + else: + from swift.llm import PtEngine + self.engine = PtEngine.from_model_template(self.model, copy(self.template), max_batch_size=0) # 0: no limit + # Avoid thread-unsafe modifications of the mode. + self.request_config = RequestConfig( + max_tokens=args.max_completion_length, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + repetition_penalty=args.repetition_penalty, + stop=args.stop_words, + ) + + if local_world_size == self.args.num_infer_workers == get_device_count() and local_world_size > 1: + self.request_config.n = self.args.tensor_parallel_size + if self.infer_rank >= 0: + self.request_config.seed = self.infer_rank // self.args.tensor_parallel_size + + self.model_accepts_loss_kwargs = False + + for i, reward_func in enumerate(self.reward_funcs): + if isinstance(reward_func, PreTrainedModel): + if self.is_deepspeed_enabled: + self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator) + else: + self.reward_funcs[i] = self.accelerator.prepare_model( + reward_func, evaluation_mode=True, device_placement=True) + + # Multi-step + self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper + self.epsilon_low = args.epsilon + self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon + + # Tracks the number of iterations (forward + backward passes), including those within a gradient accumulation cycle. # noqa + self._step = 0 + # Buffer the batch to reuse generated outputs across multiple updates. For more details, see + # `_get_train_sampler` and `_prepare_inputs`. 
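+        # With num_iterations (mu in the GRPO paper) > 1, each generated batch is reused for that many
+        # optimization cycles before new rollouts are produced.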
+ self._buffered_inputs = None + if self.args.async_generate: + self.add_callback(GRPOCallback(self)) + + if self.args.dynamic_sample: + self.resample_dataset = deepcopy(self.train_dataset) + + def cyclic_iter(iterable): + while True: + for x in iterable: + yield x + + self.resample_iterator = cyclic_iter(self.get_resample_dataloader()) + # flag indicating whether the evaluation has started + self.eval_flag = False + + @profiling_decorator + def _prepare_inputs( + self, accumulated_local_batch: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]: + mode = 'train' if self.model.training else 'eval' + if mode == 'train': + generate_every = self.args.gradient_accumulation_steps * self.num_iterations + if self._step % generate_every == 0 or self._buffered_inputs is None: + accumulated_local_batch = self._generate_and_score_completions(accumulated_local_batch) + self._buffered_inputs = accumulated_local_batch # < this is the change + inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] + self._step += 1 + else: + inputs = self._generate_and_score_completions(accumulated_local_batch) + return inputs + + def split_batches(self): + """Sync weights in batches + Only split LLM layers for now: + 1. N batches for layers + 2. other, embeds, lm_heads in one batch + 3. multi-modal components in one batch + """ + model = self.accelerator.unwrap_model(self.model) + if self.args.move_model_batches is None: + # All in one + return [[n for n, p in model.named_parameters() if 'ref_model' not in n]], [None] + + model_arch = get_model_arch(model.model_meta.model_arch) + non_llm_parameters = [] + llm_embeds = [] + parameters = [] + pattern = r'\.(\d+)\.' + + layer_count = None + # Get the number of layers in LLM modules + for name, module in model.named_modules(): + if isinstance(module, ModuleList): + if model_arch is not None and isinstance(model_arch, MultiModelKeys): + llm = model_arch.language_model + vision_tower = model_arch.vision_tower + if any(vt in name for vt in vision_tower): + continue + if isinstance(llm, list): + llm = llm[0] + if name.startswith('base_model'): + name = name.replace('base_model.', '') + if llm in name: + layer_count = len(module) + else: + layer_count = len(module) + assert layer_count is not None, 'Cannot find ModuleList to split modules.' 
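+        # For example, with layer_count=32 and move_model_batches=4, n_layers=8, so transformer
+        # layers 0-7, 8-15, 16-23 and 24-31 form four weight-sync groups; parameters without a
+        # numeric layer index (embeddings, lm_head) and multi-modal modules get trailing groups.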
+ + n_layers = ceil(layer_count / self.args.move_model_batches) + for _ in range(self.args.move_model_batches): + parameters.append([]) + + def replace_lora(name): + if 'lora_' in name: + return '' + else: + return name.replace('base_layer.', '') + + def remove_lora_and_prefix(names): + names = set([re.sub(r'^_model\.', '', replace_lora(n)) for n in names]) + return [n for n in names if n] + + def split_llm(name): + match = re.search(pattern, name) + if match: + number = match.group(1) + group = int(number) // n_layers + parameters[group].append(name) + else: + llm_embeds.append(name) + + for name, parameter in model.named_parameters(): + if 'ref_model' in name: + continue + if model_arch is not None and isinstance(model_arch, MultiModelKeys): + llm = model_arch.language_model + vision_tower = model_arch.vision_tower + if any(vt in name for vt in vision_tower): + non_llm_parameters.append(name) + elif isinstance(llm, list): + llm = llm[0] + if llm in name: + split_llm(name) + else: + non_llm_parameters.append(name) + else: + split_llm(name) + + if llm_embeds: + parameters.append(llm_embeds) + if non_llm_parameters: + parameters.append(non_llm_parameters) + parameters = [p for p in parameters if p] + parameters_no_lora = [remove_lora_and_prefix(p_list) for p_list in parameters] + return parameters, parameters_no_lora + + def prepare_vllm(self, model, fast_infer_device): + from swift.tuners import Swift + from swift.llm import VllmEngine + from swift.llm.infer.infer_engine import GRPOVllmEngine + _, _, _, local_world_size = get_dist_setting() + if self.args.tensor_parallel_size > 1: + vllm_kwargs = {'distributed_executor_backend': 'external_launcher'} + else: + vllm_kwargs = {} + if local_world_size == self.args.num_infer_workers == get_device_count() and local_world_size > 1: + # Compatibility with TP + cls = GRPOVllmEngine + engine_kwargs = {'seed': 0} + else: + cls = VllmEngine + engine_kwargs = {} + with Swift.grpo_context(model, self.template.processor): + engine = cls( + model.model_dir, + model.model_info.torch_dtype, + model_type=model.model_meta.model_type, + device=fast_infer_device[self.local_infer_rank], + tensor_parallel_size=self.args.tensor_parallel_size, + gpu_memory_utilization=self.args.vllm_gpu_memory_utilization, + enable_prefix_caching=self.args.vllm_enable_prefix_caching, + max_num_seqs=self.args.vllm_max_num_seqs, + enforce_eager=self.args.vllm_enforce_eager, + limit_mm_per_prompt=self.args.vllm_limit_mm_per_prompt, + num_infer_workers=self.args.num_infer_workers, + enable_sleep_mode=self.args.sleep_level > 0, + use_async_engine=False, + max_model_len=self.args.vllm_max_model_len, + engine_kwargs=engine_kwargs, + **vllm_kwargs) + engine.default_template = self.template + return engine + + @property + def infer_rank(self): + if self.is_external_vllm: + # When using external vLLM, only the main process (rank=0) acts as the client. 
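+            # All other processes get infer_rank -1, i.e. they do not own an inference engine
+            # (engine-related code paths are guarded by `self.infer_rank >= 0`).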
+ return 0 if self.accelerator.is_main_process else -1 + rank, local_rank, world_size, local_world_size = get_dist_setting() + node_rank = get_node_setting()[0] + for _vllm_rank in range(self.args.num_infer_workers): + if local_rank == _vllm_rank: + return node_rank * self.args.num_infer_workers + _vllm_rank + if local_rank == -1: + return 0 + return -1 + + @property + def infer_rank_tp_0(self): + # whether is tp rank0, get data from this rank + # vllm needs all tp ranks inputs and sampling params are the same + rank, local_rank, world_size, local_world_size = get_dist_setting() + node_rank = get_node_setting()[0] + for _vllm_rank in range(self.args.num_infer_workers): + if local_rank == _vllm_rank and _vllm_rank % self.args.tensor_parallel_size == 0: + return (node_rank * self.args.num_infer_workers + _vllm_rank // self.args.tensor_parallel_size) + if local_rank == -1: + return 0 + return -1 + + @property + def local_infer_rank(self): + rank, local_rank, world_size, local_world_size = get_dist_setting() + for _vllm_rank in range(self.args.num_infer_workers): + if local_rank == _vllm_rank: + return _vllm_rank + + return -1 + + def tp_group_ranks(self): + rank, local_rank, world_size, local_world_size = get_dist_setting() + return [ + list(range(0, world_size))[i:i + self.args.tensor_parallel_size] + for i in range(0, world_size, self.args.tensor_parallel_size) + ] + + @contextmanager + def _template_context(self, template): + # The max_length for prompt and completion has already been restricted, so there is no need for max_length here. + max_length = template.max_length + mode = template.mode + if mode in {'vllm', 'pt', 'lmdeploy'}: + template.set_mode('train') + template.max_length = None + loss_scale = template.loss_scale + if self.multi_turn_func: + template.loss_scale = 'default' + try: + yield + finally: + template.loss_scale = loss_scale + template.set_mode(mode) + template.max_length = max_length + + @profiling_decorator + def _move_model_to_vllm_lmdeploy(self): + if self.is_external_vllm: + return super()._move_model_to_vllm() + + from accelerate.utils.other import is_compiled_module + + for i, parameter_group in enumerate(self.parameter_groups): + parameter_group_no_lora = self.parameter_groups_no_lora[i] + with unwrap_model_for_generation( + self.model, + self.accelerator, + gather_deepspeed3_params=self.args.ds3_gather_for_generation, + gather_parameters=parameter_group) as unwrapped_model: + + if is_compiled_module(unwrapped_model): + unwrapped_model = unwrapped_model._orig_mod + if is_peft_model(unwrapped_model): + with patch_lora_merge(unwrapped_model, parameter_group): + unwrapped_model.merge_adapter() + state_dict = unwrapped_model.state_dict() + # Remove base_model and base_layer prefixes + state_dict = { + k.removeprefix('base_model.model.').replace('.base_layer', ''): v + for k, v in state_dict.items() + } + # Remove values with adapter prefix (example: "_lora") + state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k} + # When module to save, remove its prefix and discard the original module + state_dict = { + k.replace('modules_to_save.default.', ''): v + for k, v in state_dict.items() if 'original_module' not in k + } + else: + state_dict = unwrapped_model.state_dict() + if parameter_group_no_lora: + parameter_group_no_lora = [n.replace('base_model.model.', '') for n in parameter_group_no_lora] + state_dict = {k: v for k, v in state_dict.items() if k in parameter_group_no_lora} + assert len(state_dict) > 0 and all([state.shape != 
torch.Size([0]) for state in state_dict.values()]) + if self.infer_rank >= 0: + if self.args.async_generate: + self._wait_queue() + if self.args.use_vllm: + llm_model = self.engine.inner_model + else: + llm_model = self.engine.engine.engine + llm_model.load_weights(state_dict.items()) + del state_dict + gc_collect() + # Unmerge the adapter to restore the model to its original state. + # This must be done after loading weights to ensure they correspond to the merged state. + if is_peft_model(unwrapped_model): + with patch_lora_unmerge(unwrapped_model): + unwrapped_model.unmerge_adapter() + + if self.infer_rank >= 0 and self.args.use_vllm and self.args.vllm_enable_prefix_caching: + self.engine.engine.reset_prefix_cache() + + def _wait_queue(self): + while self._queue.empty(): + time.sleep(0.01) + + @staticmethod + def reorder_outputs(outputs, distributed_idx): + index_to_output = {} + current_position = 0 + for output_idx in distributed_idx: + for idx in output_idx: + index_to_output[idx] = outputs[current_position] + current_position += 1 + + return [index_to_output[idx] for idx in sorted(index_to_output.keys())] + + def _infer_multi_turn(self, inputs_slice: np.ndarray, request_config: RequestConfig) -> Union[OutputsType, List]: + """Perform multi-turn or single-turn inference with support for tensor parallelism. + + Args: + inputs_slice: Array of input requests + request_config: Inference configuration parameters + + Returns: + List of outputs where each entry contains: + - List of responses per prompt (length = tensor_parallel_size) + - Each response is a tuple of (message_history, finish_reason) + """ + from swift.llm.infer.protocol import ChatCompletionResponse + rank, _, _, _ = get_dist_setting() + request_config = copy(request_config) + results: List[ChatCompletionResponse] = self._engine_infer( + infer_requests=inputs_slice, request_config=request_config, use_tqdm=False) + prompt_lens = len(inputs_slice) + messages_list = [None] * (len(inputs_slice) * self.args.tensor_parallel_size) + if self.multi_turn_func: + remove_response = True + while len(inputs_slice) > 0: + request_config.n = 1 + if self.infer_rank_tp_0 >= 0 or not self.use_fast_infer: + inputs = [] + cnt = 0 + for i, output in enumerate(results): + for choice in output.choices: + _input: Dict = deepcopy(inputs_slice[i]) + if remove_response or _input['messages'][-1]['role'] != 'assistant' or not \ + _input['messages'][-1]['content']: + InferRequest.remove_response(_input['messages']) + _input['messages'].append({'role': 'assistant', 'content': choice.message.content}) + else: + _input['messages'][-1]['content'] += choice.message.content + if 'index' not in _input: + _input['index'] = cnt + _input['finish_reason'] = choice.finish_reason + cnt += 1 + inputs.append(_input) + results: List[Dict] = self.multi_turn_func(inputs) # noqa + else: + length = sum([len(results[i].choices) for i in range(len(results))]) + results = [None] * length + + if self.args.tensor_parallel_size > 1: + # avoid duplicate calling in the same tensor parallel group + import torch.distributed as dist + if 'group_src' in inspect.signature(dist.broadcast_object_list).parameters: + dist.broadcast_object_list(results, group_src=0, group=self.group) + else: + global_src = dist.get_global_rank(self.group, 0) + dist.broadcast_object_list(results, src=global_src, group=self.group) + inputs_slice = [r for r in results if not r['finished']] + for idx, r in enumerate(results): + if r['finished'] or r['finish_reason'] == 'length': + messages_list[r['index']] = 
(r['messages'], r['finish_reason']) + if len(inputs_slice) > 0: + _input_std = [] + for _input in inputs_slice: + _input_std.append(StdTemplateInputs.from_dict(_input)) + # StdTemplateInputs will not remove responses in infer + results = self._engine_infer( + infer_requests=_input_std, request_config=request_config, use_tqdm=False) + # concat responses from the second loop + remove_response = False + + outputs = [] + assert not any([m is None for m in messages_list]) + for i in range(0, len(messages_list), self.args.tensor_parallel_size): + # reformat to [[x, x, x, x] [x, x, x, x]] + # this is the same format of sampling_params.n > 1 + outputs.append(messages_list[i:i + self.args.tensor_parallel_size]) + assert len(outputs) == prompt_lens + assert all([len(o) == self.args.tensor_parallel_size for o in outputs]) + else: + # single turn + outputs = [] + for i, output in enumerate(results): + _choices = [] + for choice in output.choices: + _input: Dict = deepcopy(inputs_slice[i]) + InferRequest.remove_response(_input['messages']) + _input['messages'].append({'role': 'assistant', 'content': choice.message.content}) + _choices.append((_input['messages'], choice.finish_reason)) + outputs.append(_choices) + assert len(outputs) == prompt_lens + assert all([len(o) == self.args.tensor_parallel_size for o in outputs]) + + if self.args.tensor_parallel_size > 1: + if self.infer_rank_tp_0 < 0: + outputs = [] + else: + _outputs = [] + for tp_idx in range(self.args.tensor_parallel_size): + for prompt_idx in range(len(outputs)): + _outputs.append(outputs[prompt_idx][tp_idx]) + outputs = [_outputs] + + return outputs + + def async_infer(self, inputs, inputs_slice, distributed_idx): + + def infer_task(): + with set_device_context(self.infer_device), self.multi_turn_completion_length_context(): + return self._infer_multi_turn(inputs_slice, self.request_config) + + future: Future = self.executor.submit(infer_task) + # pre-fetch the queue to avoid switching back to eval_queue at the end of training sample sampling + current_queue = self._queue + + def done(_self): + current_queue.put(DataCache(inputs, _self.result(), distributed_idx)) + + future.add_done_callback(done) + + def _prefetch(self, dataloader: DataLoader): + inputs = next(iter(dataloader)) + all_inputs = gather_object(inputs) + nnodes = get_node_setting()[1] + distributed_idx = round_robin(len(all_inputs), nnodes * self.args.num_infer_workers) + if self.infer_rank >= 0: + _input_slice = np.array(all_inputs)[distributed_idx[self.infer_rank]] + with self.multi_turn_completion_length_context(): + outputs = self._infer_multi_turn(_input_slice, self.request_config) + self._queue.put(DataCache(inputs, outputs, distributed_idx)) + else: + self._queue.put(DataCache(inputs, [], distributed_idx)) + if self.accelerator.num_processes > 1: + self.accelerator.wait_for_everyone() + + def _fast_infer(self, inputs: InputsType) -> Tuple[InputsType, OutputsType]: + """ + This function performs fast inference by managing model and optimizer offloading, + loading weights if necessary, distributing inputs among workers, and generating + completions using the vLLM/LMDeploy framework. It supports both synchronous and asynchronous + inference modes. 
+ inputs: local inputs + """ + + if not self.is_external_vllm and self.args.sleep_level > 0 and self.infer_rank >= 0: + if self.args.offload_model: + self.offload_model() + if self.args.offload_optimizer: + self.offload_optimizer() + if self.args.gc_collect_after_offload: + gc_collect() + # Skip the first wake_up to avoid the warning "Executor is not sleeping" + if self.engine.inner_model_executor.is_sleeping: + self.engine.engine.wake_up() + # First, have main process load weights if needed + if self.state.global_step != self._last_loaded_step: + self._move_model_to_vllm_lmdeploy() + self._last_loaded_step = self.state.global_step + all_inputs = gather_object(inputs) + # Generate completions using vLLM: gather all prompts and use them in a single call in the main process + # Distribute inputs to different workers + # for example, 2 workers, 6 inputs, 0/2/4 dispatch to the first worker + # 1/3/5 dispatch to the second worker + # trying to shuffle and average the length + nnodes = get_node_setting()[1] + num_workers = 1 if self.is_external_vllm else nnodes + distributed_idx = round_robin(len(all_inputs), num_workers * self.args.num_infer_workers) + if self.infer_rank >= 0: + _input_slice = np.array(all_inputs)[distributed_idx[self.infer_rank]] + if self.args.async_generate: + self.async_infer(inputs, _input_slice, distributed_idx) + data_cache = self._queue.get() + inputs = data_cache.inputs + outputs = data_cache.outputs + distributed_idx = data_cache.distributed_idx + else: + with set_device_context(self.infer_device): + request_config = copy(self.request_config) + if self.args.tensor_parallel_size > 1: + request_config.seed += self.state.global_step + with self.multi_turn_completion_length_context(): + outputs = self._infer_multi_turn(_input_slice, self.request_config) + else: + if self.args.async_generate: + # using old model to generate, which will ignore the `clip` of advantages. + self._queue.put(DataCache(inputs, [], distributed_idx)) + data_cache = self._queue.get() + inputs = data_cache.inputs + distributed_idx = data_cache.distributed_idx + outputs = [] + outputs = gather_object(outputs) + if self.args.tensor_parallel_size > 1: + outputs = [[item] for output in outputs for item in output] + if not self.is_external_vllm: + outputs = self.reorder_outputs(outputs, distributed_idx) + if not self.is_external_vllm and self.args.sleep_level > 0 and self.infer_rank >= 0: + self.engine.engine.sleep(level=self.args.sleep_level) + if self.args.gc_collect_after_offload: + gc_collect() + if self.args.offload_model: + self.load_model() + if self.args.offload_optimizer: + self.load_optimizer() + return inputs, outputs + + def _generate_completions(self, inputs: InputsType) -> InputsType: + """Generate completions for given inputs using either fast inference or standard PyTorch inference. + + Args: + inputs: List of input examples containing conversation messages. + + Returns: + Modified inputs with generated completions added to the last message + and truncation flag set in 'is_truncated' field. 
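+
+        Note: with fast inference the gathered outputs are sliced back to this process's local
+        share (see `process_slice` below); the PyTorch engine path generates on local inputs only.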
+ """ + mode = 'train' if self.model.training else 'eval' + if self.use_fast_infer: + inputs, outputs = self._fast_infer(inputs) + # Slice to keep only the local part of the data + process_slice = slice( + self.accelerator.process_index * len(inputs), + (self.accelerator.process_index + 1) * len(inputs), + ) + outputs = outputs[process_slice] + else: + # pt infer + is_multimodal = self.model.model_meta.is_multimodal + if is_multimodal: + models = self.template.remove_post_encode_hook() + with unwrap_model_for_generation( + self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ), self.multi_turn_completion_length_context(): + outputs = self._infer_multi_turn(inputs, self.request_config) + if mode == 'train': + # In training mode, ensure the model is returned to train() mode after inference + # This is necessary as pt engines set the model to eval mode during generation + self.model.train() + if is_multimodal: + self.template.register_post_encode_hook(models) + if isinstance(outputs[0][0], list): + outputs = [output[0] for output in outputs] + + for i, output in enumerate(outputs): + inputs[i]['messages'] = output[0][0] + inputs[i]['is_truncated'] = output[0][1] == 'length' + + return inputs + + def _generate_and_score_completions(self, inputs: InputsType) -> InputsType: + + inputs = self._generate_completions(inputs) + total_rewards_per_func, total_rewards, completions = self._score_completions(inputs) + mode = 'train' if self.model.training else 'eval' + + if self.args.dynamic_sample and mode == 'train': + # dynamic sampling for std=0 groups + inputs, total_rewards, total_rewards_per_func, completions = \ + self._dynamic_sampling(inputs, total_rewards, total_rewards_per_func, completions) + + # Prepare final outputs with advantages and other required fields + batch_encoded_inputs = self._prepare_batch_inputs(inputs, total_rewards) + # Log metrics + messages = [inputs[i]['messages'][:-1] for i in range(len(inputs))] + + self._log_metrics(batch_encoded_inputs, messages, completions, total_rewards, total_rewards_per_func) + + return batch_encoded_inputs + + def _score_completions(self, inputs: InputsType) -> Tuple[torch.Tensor, torch.Tensor, List[str]]: + """Score completions using all reward functions + + Args: + inputs: List of input examples, each containing a 'messages' list with conversation history + + Returns: + Tuple containing: + - rewards_per_func: Tensor of shape (num_examples, num_reward_funcs) with individual rewards + - total_rewards: Tensor of shape (num_examples,) with weighted sum of rewards + - completions: List of generated completion strings + """ + device = self.accelerator.device + completions = [example['messages'][-1]['content'] for example in inputs] + rewards_per_func = torch.zeros((len(inputs), len(self.reward_funcs)), device=device) + + for i, (reward_func, reward_model_plugin) in enumerate(zip(self.reward_funcs, self.reward_model_plugins)): + # reward model + if isinstance(reward_func, nn.Module): + rewards_per_func[:, i] = reward_model_plugin(inputs=inputs) + # reward function + else: + # Repeat all input columns (but "messages" and "completion") to match the number of generations + reward_kwargs = RowPreprocessor.rows_to_batched(inputs) + output_reward_func = reward_func(completions, **reward_kwargs) + rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + + total_rewards_per_func = gather(rewards_per_func) + total_rewards = (total_rewards_per_func * 
self.reward_weights.to(device).unsqueeze(0)).sum(dim=1) + + return total_rewards_per_func, total_rewards, completions + + def _dynamic_sampling(self, inputs, rewards, rewards_per_func, completions): + # DAPO https://arxiv.org/abs/2503.14476 + # Replaces samples with zero-reward-variance groups (std=0) + resample_count = 0 + valid_samples = [] + valid_rewards = [] + valid_rewards_per_func = [] + valid_completions = [] + + origin_data = (inputs, rewards, rewards_per_func, completions) + + while resample_count < self.args.max_resample_times: + grouped_rewards = rewards.view(-1, self.num_generations) + group_std = grouped_rewards.std(dim=1) + + valid_mask = (group_std > 0).repeat_interleave(self.num_generations) + all_inputs = gather_object(inputs) + valid_samples.extend([inp for inp, mask in zip(all_inputs, valid_mask) if mask]) + valid_rewards.append(rewards[valid_mask]) + valid_rewards_per_func.append(rewards_per_func[valid_mask]) + valid_completions.extend( + [inp['messages'][-1]['content'] for inp, mask in zip(all_inputs, valid_mask) if mask]) + + if len(valid_samples) >= self.effective_train_batch_size: + break + + inputs = next(self.resample_iterator) + inputs = Trainer._prepare_inputs(self, inputs) + inputs = self._generate_completions(inputs) + rewards_per_func, rewards, completions = self._score_completions(inputs) + resample_count += 1 + + if len(valid_samples) >= self.effective_train_batch_size: + process_slice = slice( + self.accelerator.process_index * len(inputs), + (self.accelerator.process_index + 1) * len(inputs), + ) + inputs = valid_samples[:self.effective_train_batch_size][process_slice] + rewards = torch.cat(valid_rewards)[:self.effective_train_batch_size] + rewards_per_func = torch.cat(valid_rewards_per_func)[:self.effective_train_batch_size] + completions = valid_completions[:self.effective_train_batch_size][process_slice] + else: + logger.warning(f'There are still std=0 groups present after {self.args.max_resample_times} retries.') + inputs, rewards, rewards_per_func, completions = origin_data + + return inputs, rewards, rewards_per_func, completions + + def _prepare_batch_inputs(self, inputs: InputsType, rewards: torch.Tensor) -> List[InputsType]: + """ + Prepare the final batch inputs with advantages, ref/old_policy logps and other fields for RL training. + + Args: + inputs (InputsType): List of input samples. Original shape is [gas*bs] where: + - gas: gradient accumulation steps + - bs: per-device batch size + rewards (torch.Tensor): Tensor of rewards corresponding to the inputs. 
+ Shape should match the total number of samples (gas*bs*num_generations) + + Returns: + List[InputsType]: A list of prepared batch inputs, organized as [gas][bs] + """ + # Compute advantages + grouped_rewards = rewards.view(-1, self.num_generations) + mean_grouped_rewards = grouped_rewards.mean(dim=1).repeat_interleave(self.num_generations, dim=0) + std_grouped_rewards = grouped_rewards.std(dim=1).repeat_interleave(self.num_generations, dim=0) + advantages = (rewards - mean_grouped_rewards) + if self.args.scale_rewards: + advantages /= (std_grouped_rewards + 1e-4) + + # Slice to keep only the local part of the data + process_slice = slice( + self.accelerator.process_index * len(inputs), + (self.accelerator.process_index + 1) * len(inputs), + ) + advantages = advantages[process_slice] + + mode = 'train' if self.model.training else 'eval' + bs = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size + gas = self.args.gradient_accumulation_steps if mode == 'train' else 1 + + assert len(inputs) == bs * gas, f'Expected {bs * gas} inputs, got {len(inputs)}' + gas_chunks = [inputs[i * bs:(i + 1) * bs] for i in range(gas)] + + ga_batch_encoded_inputs = [] + template = self.template + + # Split advantages by GAS chunks + advantage_chunks = torch.chunk(advantages, gas) + + for i, (batch, batch_advantages) in enumerate(zip(gas_chunks, advantage_chunks)): + # Encode and process each batch (size=bs) + with self._template_context(template): + batch_encoded_inputs = [template.encode(infer_request) for infer_request in batch] + batch_encoded_inputs = to_device(template.data_collator(batch_encoded_inputs), self.model.device) + + # Process labels and masks + labels = batch_encoded_inputs.pop('labels') + logits_to_keep = (labels.shape[-1] - (torch.ne(labels, -100).int().argmax(-1))).max().item() + batch_encoded_inputs.update({ + 'completion_mask': + labels[:, -logits_to_keep:] != -100, + 'truncated_mask': + torch.tensor([b['is_truncated'] for b in batch], dtype=torch.bool), + 'logits_to_keep': + logits_to_keep, + 'advantages': + batch_advantages + }) + + with torch.no_grad(): + batch_encoded_inputs['old_per_token_logps'] = ( + self._get_per_token_logps(self.model, batch_encoded_inputs) if self.old_policy else None) + + if self.beta == 0.0: + ref_per_token_logps = None + elif self.ref_model is not None: + ref_per_token_logps = self._get_per_token_logps(self.ref_model, batch_encoded_inputs) + else: + with self.accelerator.unwrap_model(self.model).disable_adapter(): + ref_per_token_logps = self._get_per_token_logps(self.model, batch_encoded_inputs) + batch_encoded_inputs['ref_per_token_logps'] = ref_per_token_logps + + ga_batch_encoded_inputs.append(batch_encoded_inputs) + + return ga_batch_encoded_inputs + + def _log_metrics(self, inputs, messages, completions, rewards, rewards_per_func): + """Log training/evaluation metrics""" + mode = 'train' if self.model.training else 'eval' + device = self.accelerator.device + + # Calculate completion length metrics + agg_completion_mask = gather(torch.cat([inp['completion_mask'].sum(1) for inp in inputs])) + + self._metrics[mode]['completions/mean_length'].append(agg_completion_mask.float().mean().item()) + self._metrics[mode]['completions/min_length'].append(agg_completion_mask.float().min().item()) + self._metrics[mode]['completions/max_length'].append(agg_completion_mask.float().max().item()) + # Calculate clip ratio + agg_truncated_mask = gather(torch.cat([inp['truncated_mask'] for inp in inputs]).to(device)) + + 
term_completion_mask = agg_completion_mask[agg_truncated_mask] + clipped_completions_ratio = len(term_completion_mask) / len(agg_completion_mask) + + self._metrics[mode]['completions/clipped_ratio'].append(clipped_completions_ratio) + + for i, reward_func_name in enumerate(self.reward_func_names): + mean_rewards = rewards_per_func[:, i].mean().item() + self._metrics[mode][f'rewards/{reward_func_name}/mean'].append(mean_rewards) + std_rewards = rewards_per_func[:, i].std().item() + self._metrics[mode][f'rewards/{reward_func_name}/std'].append(std_rewards) + + # Log overall reward stats + grouped_rewards = rewards.view(-1, self.num_generations) + self._metrics[mode]['reward'].append(grouped_rewards.mean().item()) + self._metrics[mode]['reward_std'].append(grouped_rewards.std(dim=1).mean().item()) + + # Log prompt and completion texts + self._textual_logs['prompt'].extend(gather_object(messages)) + self._textual_logs['completion'].extend(gather_object(completions)) + for i, name in enumerate(self.reward_func_names): + self._textual_logs['rewards'][name].extend(rewards_per_func[:, i].tolist()) + + @profiling_decorator + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + # Compute the per-token log probabilities for the model, return_outputs=True in mini-batch training + if isinstance(inputs, list): + assert len(inputs) == 1 + inputs = inputs[0] + completion_mask = inputs['completion_mask'] + truncated_mask = inputs['truncated_mask'] + # apply the completion_mask to exclude loss and metrics for overlong completions + if self.args.overlong_filter and any(truncated_mask): + if all(truncated_mask): + logger.info('All completions are overlong, loss and KL will be zero') + truncated_mask = truncated_mask.unsqueeze(-1).expand_as(completion_mask).to(completion_mask.device) + completion_mask = completion_mask * (~truncated_mask) + + per_token_logps = self._get_per_token_logps(model, inputs) + + # Compute the KL divergence between the model and the reference model + if self.beta != 0.0: + ref_per_token_logps = inputs['ref_per_token_logps'] + per_token_kl = ( + torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1) + + advantages = inputs['advantages'] + old_per_token_logps = inputs['old_per_token_logps'] if self.old_policy else per_token_logps.detach() + coef_1 = torch.exp(per_token_logps - old_per_token_logps) + coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) + per_token_loss1 = coef_1 * advantages.unsqueeze(1) + per_token_loss2 = coef_2 * advantages.unsqueeze(1) + per_token_loss = -torch.min(per_token_loss1, per_token_loss2) + if self.beta != 0.0: + per_token_loss = per_token_loss + self.beta * per_token_kl + + if self.loss_type == 'grpo': + loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean() + elif self.loss_type == 'bnpo': + loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) + elif self.loss_type == 'dr_grpo': + loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) + else: + raise ValueError(f'Unknown loss type: {self.loss_type}') + + # Log the metrics + mode = 'train' if self.model.training else 'eval' + + if self.beta != 0.0: + mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum() + self._metrics[mode]['kl'].append(self.accelerator.gather_for_metrics(mean_kl).nanmean().item()) + + # Compute the clipped probability ratios + is_low_clipped = 
(coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0) + is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0) + is_region_clipped = is_low_clipped | is_high_clipped + + low_clip = (is_low_clipped * completion_mask).sum() / completion_mask.sum() + high_clip = (is_high_clipped * completion_mask).sum() / completion_mask.sum() + clip_ratio = (is_region_clipped * completion_mask).sum() / completion_mask.sum() + + gathered_low_clip = self.accelerator.gather_for_metrics(low_clip) + self._metrics[mode]['clip_ratio/low_mean'].append(gathered_low_clip.nanmean().item()) + self._metrics[mode]['clip_ratio/low_min'].append(nanmin(gathered_low_clip).item()) + gathered_high_clip = self.accelerator.gather_for_metrics(high_clip) + self._metrics[mode]['clip_ratio/high_mean'].append(gathered_high_clip.nanmean().item()) + self._metrics[mode]['clip_ratio/high_max'].append(nanmax(gathered_high_clip).item()) + gathered_clip_ratio = self.accelerator.gather_for_metrics(clip_ratio) + self._metrics[mode]['clip_ratio/region_mean'].append(gathered_clip_ratio.nanmean().item()) + + return loss + + # Get the per-token log probabilities for the completions for the model and the reference model + @profiling_decorator + def _get_per_token_logps(self, model, inputs): + from trl.trainer.utils import selective_log_softmax + logits_to_keep = inputs['logits_to_keep'] + input_ids = inputs['input_ids'] + unwrapped_model = self.accelerator.unwrap_model(model) + if is_peft_model(unwrapped_model): + parameters = inspect.signature(unwrapped_model.base_model.model.forward).parameters + else: + parameters = inspect.signature(unwrapped_model.forward).parameters + if not unwrapped_model.model_meta.is_multimodal and 'logits_to_keep' in parameters: + # save memory + return super()._get_per_token_logps(model, input_ids, inputs['attention_mask'], logits_to_keep) + inputs = { + k: v + for k, v in inputs.items() if k not in [ + 'logits_to_keep', 'completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps', + 'truncated_mask' + ] + } + with self._template_context(self.template): + logits = model(**inputs).logits + # exclude the last logit: it corresponds to the next token pred + logits = logits[:, -(logits_to_keep + 1):-1, :] + logits = logits / self.temperature + input_ids = input_ids[:, -logits_to_keep:] + return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens + + def evaluation_loop(self, dataloader, *args, **kwargs): + # Wait for the training rollout to complete + if self.args.async_generate: + while not self.is_async_generate_eval_rollout_done(): + time.sleep(0.1) + if self._queue.empty() and self.args.async_generate: + self._prefetch(dataloader) + metric_key_prefix = kwargs['metric_key_prefix'] + output = super().evaluation_loop(dataloader, *args, **kwargs) + metrics = {f'{metric_key_prefix}_{key}': sum(val) / len(val) for key, val in self._metrics['eval'].items()} + output.metrics.update(metrics) + self.eval_flag = True + return output + + def training_step(self, model: nn.Module, inputs: InputsType, num_items_in_batch=None) -> torch.Tensor: + if self.args.async_generate: + # Wait for the eval rollout to complete + while not self.is_async_generate_eval_rollout_done(): + time.sleep(0.1) + return super().training_step(model, inputs, num_items_in_batch) + + def _engine_infer( + self, + infer_requests: List[InferRequest], + request_config: Optional[RequestConfig] = None, + *, + use_tqdm: Optional[bool] = None, + ): + if self.is_external_vllm: + 
self._process_infer_requests_images(infer_requests) + return self.vllm_client.infer(infer_requests.tolist(), asdict(request_config), use_tqdm=use_tqdm) + else: + return self.engine.infer(infer_requests, request_config, use_tqdm=use_tqdm) + + def _process_infer_requests_images(self, infer_requests: List[InferRequest]): + import base64 + if not any('images' in request for request in infer_requests): + return + for request in infer_requests: + if 'images' not in request: + continue + for i, img in enumerate(request['images']): + if 'bytes' in img and img['bytes']: + request['images'][i] = base64.b64encode(img['bytes']).decode('utf-8') + return + + @property + def old_policy(self): + return self.num_iterations > 1 + + @property + def _queue(self): + if self.control.should_evaluate: + return self.eval_queue + else: + return self.train_queue + + @torch.no_grad() + def offload_model(self): + if len(self.offload_modules) > 0: + return + unwrapped_model = self.accelerator.unwrap_model(self.model) + for name, module in unwrapped_model.named_modules(): + if isinstance(module, torch.nn.Embedding): + self.offload_modules[name] = module.weight.device + module.to('cpu') + elif not hasattr(module, 'device'): + pass + elif module.device.type != 'cpu': + self.offload_modules[name] = module.device + module.to('cpu') + + @torch.no_grad() + def load_model(self): + if len(self.offload_modules) == 0: + return + unwrapped_model = self.accelerator.unwrap_model(self.model) + for name, device in self.offload_modules.items(): + module = unwrapped_model.get_submodule(name) + if isinstance(module, torch.nn.Embedding): + module.weight.to(device) + else: + module.to(device) + self.offload_modules.clear() + + @torch.no_grad() + def offload_optimizer(self): + if len(self.offload_states) > 0: + return + if not self.optimizer.state: + return + for param_group in self.optimizer.param_groups: + for param in param_group['params']: + state = self.optimizer.state[param] + for key, value in state.items(): + if isinstance(value, torch.Tensor): + self.offload_states[key] = value.device + state[key] = value.to('cpu', non_blocking=True) + + @torch.no_grad() + def load_optimizer(self): + if len(self.offload_states) == 0: + return + if not self.optimizer.state: + return + for param_group in self.optimizer.param_groups: + for param in param_group['params']: + state = self.optimizer.state[param] + for key, value in state.items(): + if isinstance(value, torch.Tensor): + state[key] = value.to(self.offload_states[key], non_blocking=True) + self.offload_states.clear() + + @contextmanager + def multi_turn_completion_length_context(self): + """ + Context manager that temporarily adjusts the engine's max length handling + for multi-turn generation scenarios. 
+ + Ensures the total sequence length (prompt + completion) never exceeds: + min(original_max_len, prompt_tokens + max_completion_length) + """ + if not (self.multi_turn_func and self.infer_rank >= 0) or self.is_external_vllm: + yield + return + + original_fn = self.engine.set_default_max_tokens + original_max_len = self.engine.max_model_len + + def set_default_max_tokens(_self, request_config: RequestConfig, inputs: InputsType) -> None: + # Calculate required context window + original_max_len = _self.max_model_len or 8192 + if isinstance(inputs, dict): + inputs = [inputs] + prompt_tokens = max(_self._get_num_tokens(inp) for inp in inputs) + + if not hasattr(_self, 'set_grpo_max_model_len'): + # set max model len in first round + max_len = min(original_max_len, prompt_tokens + request_config.max_tokens) + _self.max_model_len = max_len + _self.set_grpo_max_model_len = True + else: + if _self.max_model_len <= prompt_tokens: + # modify max_model_len > prompt_tokens to avoid crash + num_tokens_avoid_crash = 10 + _self.max_model_len = (prompt_tokens + num_tokens_avoid_crash) + request_config.max_tokens = num_tokens_avoid_crash + + original_fn(request_config, inputs) + + try: + self.engine.set_default_max_tokens = MethodType(set_default_max_tokens, self.engine) + yield + finally: + self.engine.set_default_max_tokens = original_fn + self.engine.max_model_len = original_max_len + del self.engine.set_grpo_max_model_len + + def get_resample_dataloader(self) -> DataLoader: + resample_dataset = self.resample_dataset + data_collator = self.data_collator + if isinstance(resample_dataset, datasets.Dataset): + resample_dataset = self._remove_unused_columns(resample_dataset, description='training') + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description='training') + + dataloader_params = { + 'batch_size': self._train_batch_size * self.args.gradient_accumulation_steps, + 'collate_fn': data_collator, + 'num_workers': self.args.dataloader_num_workers, + 'pin_memory': self.args.dataloader_pin_memory, + 'persistent_workers': self.args.dataloader_persistent_workers, + } + + @contextmanager + def seed_context(self): + seed = self.args.seed + self.args.seed = seed + 1 + yield + self.args.seed = seed + + if not isinstance(resample_dataset, torch.utils.data.IterableDataset): + with seed_context(self): # Set a different seed for resampling than the train_dataset. + dataloader_params['sampler'] = self._get_train_sampler() + dataloader_params['drop_last'] = self.args.dataloader_drop_last + dataloader_params['worker_init_fn'] = seed_worker + dataloader_params['prefetch_factor'] = self.args.dataloader_prefetch_factor + + return self.accelerator.prepare(DataLoader(resample_dataset, **dataloader_params)) + + def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: + mode = 'train' if self.model.training else 'eval' + metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics + + # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` + # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. 
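+        # e.g. 'completions/mean_length' gathered during evaluation is logged as
+        # 'eval_completions/mean_length'.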
+ if mode == 'eval': + metrics = {f'eval_{key}': val for key, val in metrics.items()} + + logs = {**logs, **metrics} + if version.parse(transformers.__version__) >= version.parse('4.47.0.dev0'): + super().log(logs, start_time) + else: # transformers<=4.46 + super().log(logs) + self._metrics[mode].clear() + + if self.accelerator.is_main_process and self.log_completions: + table = { + 'step': [str(self.state.global_step)] * len(self._textual_logs['prompt']), + 'prompt': self._textual_logs['prompt'], + 'completion': self._textual_logs['completion'], + **self._textual_logs['rewards'], + } + self.jsonl_writer.append(table) + if self.args.report_to and 'wandb' in self.args.report_to and wandb.run is not None: + import pandas as pd + df = pd.DataFrame(table) + if self.wandb_log_unique_prompts: + df = df.drop_duplicates(subset=['prompt']) + wandb.log({'completions': wandb.Table(dataframe=df)}) + + def is_async_generate_eval_rollout_done(self): + return not self.eval_flag or not self.eval_queue.empty() + + def is_async_generate_train_rollout_done(self): + return not self.train_queue.empty() diff --git a/swift/trainers/rlhf_trainer/kto_trainer.py b/swift/trainers/rlhf_trainer/kto_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..f56d0fd6056fe3eb1001bc862bc1f807621264aa --- /dev/null +++ b/swift/trainers/rlhf_trainer/kto_trainer.py @@ -0,0 +1,69 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from contextlib import contextmanager +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from peft import PeftModel +from transformers import PreTrainedModel +from trl import KTOTrainer as HFKTOTrainer + +from swift.utils import get_logger +from ..mixin import SwiftMixin +from .rlhf_mixin import RLHFTrainerMixin + +logger = get_logger() + +del HFKTOTrainer.__init__ + + +class KTOTrainer(RLHFTrainerMixin, SwiftMixin, HFKTOTrainer): + + def __init__(self, + model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + *_args, + **kwargs): + args = kwargs['args'] + args.disable_dropout = True + self.desirable_weight = args.desirable_weight + self.undesirable_weight = args.undesirable_weight + self.precompute_ref_log_probs = args.precompute_ref_log_probs + self.is_peft_model = isinstance(model, PeftModel) + if hasattr(args, 'loss_type'): + self.loss_type = args.loss_type + else: + self.loss_type = 'kto' + + self.ref_adapter_name = None + # Not all losses require a KL calculation + self.calculate_KL = True + if self.loss_type in ['apo_zero_unpaired']: + self.calculate_KL = False + super().__init__(model, ref_model, *_args, **kwargs) + + def forward( + self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + is_kl = True + + def _add_data_hook(model, args, kwargs): + nonlocal is_kl + if is_kl: + kwargs = {k[len('KL_completion_'):]: v for k, v in batch.items() if k.startswith('KL_completion_')} + else: + kwargs = {k[len('completion_'):]: v for k, v in batch.items() if k.startswith('completion_')} + is_kl = not is_kl + return (), kwargs + + @contextmanager + def _patch_model_call(): + handle = model.register_forward_pre_hook(_add_data_hook, with_kwargs=True, prepend=True) + + try: + yield + finally: + handle.remove() + + with _patch_model_call(): + return super().forward(model, batch) diff --git a/swift/trainers/rlhf_trainer/orpo_trainer.py 
b/swift/trainers/rlhf_trainer/orpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..9792f0d1415d41166f888be65d32bfa08dc2e844 --- /dev/null +++ b/swift/trainers/rlhf_trainer/orpo_trainer.py @@ -0,0 +1,19 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Optional, Union + +import torch.nn as nn +from transformers import PreTrainedModel +from trl import ORPOTrainer as HFORPOTrainer + +from ..mixin import SwiftMixin +from .rlhf_mixin import RLHFTrainerMixin + +del HFORPOTrainer.__init__ + + +class ORPOTrainer(RLHFTrainerMixin, SwiftMixin, HFORPOTrainer): + + def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, *_args, **kwargs): + ref_model = kwargs.get('ref_model') + assert ref_model is None, 'ORPO does not require a ref_model.' + super().__init__(model, *_args, **kwargs) diff --git a/swift/trainers/rlhf_trainer/ppo_trainer.py b/swift/trainers/rlhf_trainer/ppo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..5fc20c882b60a6b416e2306bf0a28a1eb922a5d9 --- /dev/null +++ b/swift/trainers/rlhf_trainer/ppo_trainer.py @@ -0,0 +1,65 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import inspect +from contextlib import contextmanager + +import transformers +from packaging import version +from torch.utils.data import DataLoader +from transformers import PreTrainedModel +from trl import PPOTrainer as HFPPOTrainer + +from swift.utils import patch_getattr +from ..mixin import SwiftMixin + +ppo_trainer_init = HFPPOTrainer.__init__ +del HFPPOTrainer.__init__ + + +class PPOTrainer(SwiftMixin, HFPPOTrainer): + + @staticmethod + @contextmanager + def _patch_dataloader(collate_fn): + __init__ = DataLoader.__init__ + + def __new_init__(self, *args, **kwargs): + kwargs['collate_fn'] = collate_fn + __init__(self, *args, **kwargs) + + DataLoader.__init__ = __new_init__ + try: + yield + finally: + DataLoader.__init__ = __init__ + + def __init__(self, model: PreTrainedModel, ref_model: PreTrainedModel, *_args, **kwargs): + super().__init__(model, *_args, **{k: v for k, v in kwargs.items() if k not in {'reward_model', 'value_model'}}) + with self._patch_dataloader(kwargs['data_collator']): + new_kwargs = { + k: v + for k, v in kwargs.items() + if k in ['train_dataset', 'data_collator', 'reward_model', 'value_model', 'eval_dataset'] + } + parameters = inspect.signature(ppo_trainer_init).parameters + if 'config' in parameters: + new_kwargs['config'] = kwargs['args'] + else: + new_kwargs['args'] = kwargs['args'] + if 'processing_class' in parameters: + new_kwargs['processing_class'] = self.tokenizer + else: + new_kwargs['tokenizer'] = self.tokenizer + ppo_trainer_init(self, model=model, ref_model=ref_model, **new_kwargs) + unwrap_model = self.accelerator.unwrap_model(self.model) + patch_getattr(unwrap_model.__class__, 'policy') + + def train(self, *args, **kwargs): + # remove args that are not needed for the HFPPOTrainer + super().train() + + def _save_checkpoint(self, *args, **kwargs): + if version.parse(transformers.__version__) >= version.parse('4.47'): + metrics = kwargs.pop('metrics', None) + trial = kwargs.get('trial') + self._determine_best_metric(metrics=metrics, trial=trial) + return super()._save_checkpoint(*args, **kwargs) diff --git a/swift/trainers/rlhf_trainer/reward_trainer.py b/swift/trainers/rlhf_trainer/reward_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..0355343909021eeb6af8c5f2199040302078a272 --- /dev/null +++ 
b/swift/trainers/rlhf_trainer/reward_trainer.py @@ -0,0 +1,78 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from collections import defaultdict +from typing import Any, Dict, Tuple, Union + +import pandas as pd +import torch +import torch.nn as nn +from accelerate.utils import gather_object +from transformers import PreTrainedModel +from trl import RewardTrainer as HFRewardTrainer +from trl.trainer.utils import print_rich_table + +from ..mixin import SwiftMixin +from .rlhf_mixin import RLHFTrainerMixin + +del HFRewardTrainer.__init__ + + +class RewardTrainer(RLHFTrainerMixin, SwiftMixin, HFRewardTrainer): + + def compute_loss(self, + model: Union[PreTrainedModel, nn.Module], + inputs: Dict[str, Union[torch.Tensor, Any]], + return_outputs=False, + num_items_in_batch=None) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: + inputs.pop('labels', None) # not use + attention_mask = inputs['attention_mask'] + batch_size = attention_mask.shape[0] // 2 + rewards = model(**inputs).logits + rewards_chosen, rewards_rejected = torch.split(rewards, batch_size, dim=0) + if 'margin' in inputs: + loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs['margin']).mean() + else: + loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean() + if self.args.center_rewards_coefficient is not None: + loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected)**2) + # compat transformers>=4.46.* + if num_items_in_batch is not None and self.model_accepts_loss_kwargs: + loss /= self.args.gradient_accumulation_steps + if return_outputs: + return loss, { + 'rewards_chosen': rewards_chosen, + 'rewards_rejected': rewards_rejected, + } + return loss + + def visualize_samples(self, num_print_samples: int): + """ + Visualize the reward model logits prediction + + Args: + num_print_samples (`int`, defaults to `4`): + The number of samples to print. Set to `-1` to print all samples. + """ + eval_dataloader = self.get_eval_dataloader() + table = defaultdict(list) + for _, inputs in enumerate(eval_dataloader): + _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False) + input_ids = inputs['input_ids'] + attention_mask = inputs['attention_mask'] + sequence_lengths = ((torch.eq(attention_mask, 0).int().argmax(-1) - 1) % attention_mask.shape[1]).tolist() + text = [self.template.safe_decode(tokens[:sequence_lengths[i]]) for i, tokens in enumerate(input_ids)] + batch_size = input_ids.shape[0] // 2 + chosen_text, rejected_text = text[:batch_size], text[batch_size:] + table['chosen_text'].extend(gather_object(chosen_text)) + table['rejected_text'].extend(gather_object(rejected_text)) + table['logits'].extend( + gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()])) + if 0 <= num_print_samples <= len(table['chosen_text']): + break + df = pd.DataFrame(table) + if self.accelerator.process_index == 0: + print_rich_table(df[:num_print_samples]) + if 'wandb' in self.args.report_to: + import wandb + + if wandb.run is not None: + wandb.log({'completions': wandb.Table(dataframe=df)}) diff --git a/swift/trainers/rlhf_trainer/rlhf_mixin.py b/swift/trainers/rlhf_trainer/rlhf_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..fe6511c373e7d7e80636357abee232d2e5e7c44f --- /dev/null +++ b/swift/trainers/rlhf_trainer/rlhf_mixin.py @@ -0,0 +1,104 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+from collections import defaultdict +from contextlib import contextmanager, nullcontext +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from transformers import PreTrainedModel +from transformers.integrations import is_deepspeed_zero3_enabled + +try: + from trl import AutoModelForCausalLMWithValueHead +except (ImportError, RuntimeError): + AutoModelForCausalLMWithValueHead = None + + +class RLHFTrainerMixin: + + def __init__(self, + model: Optional[Union[PreTrainedModel, nn.Module]] = None, + ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None, + *_args, + **kwargs): + from trl.trainer import disable_dropout_in_model + from swift.llm import HfConfigFactory + self.ref_model = ref_model + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + args = kwargs['args'] + self.beta = getattr(args, 'beta', 0.0) + if getattr(args, 'disable_dropout', False): + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + self.is_encoder_decoder = kwargs['template'].is_encoder_decoder + self.aux_loss_enabled = getattr(model.config, 'output_router_logits', False) + self._peft_has_been_casted_to_bf16 = False + self.generate_during_eval = getattr(args, 'generate_during_eval', False) + if self.is_encoder_decoder: + self.decoder_start_token_id = HfConfigFactory.get_config_attr(model.config, 'decoder_start_token_id') + self.pad_token_id = HfConfigFactory.get_config_attr(model.config, 'pad_token_id') + # not use + self.is_vision_model = False + self.label_pad_token_id = -100 + self.use_dpo_data_collator = True + super().__init__(model, *_args, **kwargs) + if is_deepspeed_zero3_enabled() and ref_model is not None: + try: + from trl.models.utils import prepare_deepspeed + except ImportError as e: + raise ImportError('Please install trl>=0.14 via `pip install "trl>=0.14"`') from e + prepare_deepspeed(self.ref_model, self.accelerator) # Does not wrap DeepSpeedEngine + self.padding_value = self.tokenizer.pad_token_id + + def concatenated_forward( + self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + model_kwargs = batch.copy() + labels = model_kwargs.pop('labels', None) + if self.is_encoder_decoder: + model_kwargs['labels'] = labels + + if self.aux_loss_enabled: + model_kwargs['output_router_logits'] = True + outputs = model(**model_kwargs, use_cache=False) + model_kwargs['labels'] = labels + model_kwargs['chosen_labels'] = torch.zeros(model_kwargs['labels'].shape[0] // 2) # just get shape + if outputs.logits.shape[1] != labels.shape[1]: + # for llava, the model returns logits for the entire sequence, including the image tokens + # (placed before the text tokens) + outputs.logits = outputs.logits[:, -labels.shape[1]:] + for key in ['input_ids', 'attention_mask', 'labels']: + model_kwargs[f'concatenated_{key}'] = model_kwargs.pop(key, None) + if self.__class__.__name__ == 'ORPOTrainer': # Pass-through labels + model_kwargs['concatenated_input_ids'] = model_kwargs['concatenated_labels'] + + @contextmanager + def _patch_concatenated_forward(): + _old_concatenated_inputs = self.concatenated_inputs + _old_model_call = model.__class__.__call__ + self.concatenated_inputs = lambda *args, **kwargs: model_kwargs + model.__class__.__call__ = lambda *args, **kwargs: outputs + try: + yield + finally: + self.concatenated_inputs = _old_concatenated_inputs + model.__class__.__call__ = 
_old_model_call + + with _patch_concatenated_forward(): + return super().concatenated_forward(model, model_kwargs) + + def get_batch_logps(self, logits: torch.FloatTensor, labels: torch.LongTensor, *args, **kwargs): + if self.is_encoder_decoder: + labels = labels.clone() # fix trl bug + return super().get_batch_logps(logits, labels, *args, **kwargs) + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + res = super().compute_loss(model, inputs, return_outputs=return_outputs) + # compat transformers>=4.46.* + if num_items_in_batch is not None and self.model_accepts_loss_kwargs: + loss = res[0] if return_outputs else res + loss /= self.args.gradient_accumulation_steps + return (loss, res[1:]) if return_outputs else loss + return res diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eb00f2c6fab2bf448a70ae87e811e53c57b3acdf --- /dev/null +++ b/swift/trainers/rlhf_trainer/utils.py @@ -0,0 +1,132 @@ +from contextlib import contextmanager +from types import MethodType +from typing import Any, List, Optional + +import torch +from peft.tuners import lora +from peft.tuners.lora import LoraLayer + + +def round_robin(num_reqs, num_workers): + """Distribute requests evenly across workers using round-robin algorithm. + + Args: + num_reqs (int): Total number of requests to distribute + num_workers (int): Number of available workers + + Returns: + list: A list of lists where each sublist contains the request indices + assigned to that particular node + """ + distribution = [[] for _ in range(num_workers)] + for idx in range(num_reqs): + worker_id = idx % num_workers + distribution[worker_id].append(idx) + return distribution + + +@contextmanager +def patch_lora_merge(model, parameter_group=None): + """Patch LoraLayer's merge and get_delta_weight methods for controlled merging. 
+ + Args: + model: The PEFT model to patch + parameter_group: Optional list of parameter names to restrict merging + + Yields: + The patched model (context manager ensures cleanup) + """ + from peft.tuners.tuners_utils import check_adapters_to_merge + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + if parameter_group and all(self.name not in pg for pg in parameter_group): + return # Skip if not in target parameter group + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + return + + for active_adapter in adapter_names: + if active_adapter in self.lora_A.keys(): + base_layer = self.get_base_layer() + if self.use_dora.get(active_adapter, False): + self.lora_magnitude_vector[active_adapter].weight.data = \ + self.lora_magnitude_vector[active_adapter].weight.data.to(base_layer.weight.device) + + return self.merge_origin(safe_merge, adapter_names) + + def get_delta_weight(self, adapter) -> torch.Tensor: + # Ensure tensors are on correct device + if isinstance(self, lora.Embedding): + self.lora_embedding_A[adapter].data = self.lora_embedding_A[adapter].data.to(self.base_layer.weight.device) + self.lora_embedding_B[adapter].data = self.lora_embedding_B[adapter].data.to(self.base_layer.weight.device) + else: + self.lora_A[adapter].weight.data = self.lora_A[adapter].weight.data.to(self.base_layer.weight.device) + self.lora_B[adapter].weight.data = self.lora_B[adapter].weight.data.to(self.base_layer.weight.device) + return self.get_delta_weight_origin(adapter).to(self.base_layer.weight.device) + + def _cache_pop(self, key: str) -> Any: + value = self._caches.pop(key).to(self.base_layer.weight.device) + return value + + # Patch all LoraLayer instances + for name, module in model.named_modules(): + if isinstance(module, LoraLayer): + module.name = name + if not hasattr(module, 'merge_origin') and hasattr(module, 'base_layer'): + module.merge_origin = module.merge + module.merge = MethodType(merge, module) + module.get_delta_weight_origin = module.get_delta_weight + module.get_delta_weight = MethodType(get_delta_weight, module) + module._cache_pop_origin = module._cache_pop + module._cache_pop = MethodType(_cache_pop, module) + + try: + yield model + finally: + # Cleanup: restore original methods + for module in model.modules(): + if isinstance(module, LoraLayer): + if hasattr(module, 'merge_origin'): + module.merge = module.merge_origin + del module.merge_origin + module.get_delta_weight = module.get_delta_weight_origin + del module.get_delta_weight_origin + module._cache_pop = module._cache_pop_origin + del module._cache_pop_origin + + +@contextmanager +def patch_lora_unmerge(model): + """Patch the unmerge method to ensure proper device handling.""" + + def _cache_pop_patched(self, key: str) -> Any: + value = self._caches.pop(key).to(self.base_layer.weight.device) + return value + + def unmerge_patched(self): + if not self.merged: + return + # Move magnitude vectors to correct device first + for adapter in list(self.merged_adapters): + if self.use_dora.get(adapter, False): + self.lora_magnitude_vector[adapter].weight.data = \ + self.lora_magnitude_vector[adapter].weight.data.to(self.base_layer.weight.device) + + return self.unmerge_origin() + + for module in model.modules(): + if isinstance(module, LoraLayer) and not hasattr(module, 'unmerge_origin'): + module.unmerge_origin = module.unmerge + module.unmerge = MethodType(unmerge_patched, module) + module._cache_pop_origin = module._cache_pop + module._cache_pop = 
MethodType(_cache_pop_patched, module) + + try: + yield model + finally: + for module in model.modules(): + if isinstance(module, LoraLayer) and hasattr(module, 'unmerge_origin'): + module.unmerge = module.unmerge_origin + del module.unmerge_origin + module._cache_pop = module._cache_pop_origin + del module._cache_pop_origin diff --git a/swift/trainers/rlhf_trainer/vllm_client.py b/swift/trainers/rlhf_trainer/vllm_client.py new file mode 100644 index 0000000000000000000000000000000000000000..93d4b999ec621e032102b71b459a9443b692cad0 --- /dev/null +++ b/swift/trainers/rlhf_trainer/vllm_client.py @@ -0,0 +1,212 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +# Code partially sourced from Hugging Face TRL + +import atexit +import logging +import time +from typing import List, Optional + +import requests +import torch +from dacite import from_dict +from requests import ConnectionError +from torch import nn + +from swift.llm import AdapterRequest, InferRequest, Template +from swift.llm.infer.protocol import ChatCompletionResponse, RequestConfig +from swift.plugin import Metric +from swift.utils import is_vllm_ascend_available, is_vllm_available + +if is_vllm_available(): + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + from vllm.distributed.utils import StatelessProcessGroup + + if is_vllm_ascend_available(): + from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator # noqa + +logger = logging.getLogger(__name__) + + +class VLLMClient: + """ + A client class to interact with a vLLM server. + + This class provides methods to infer completions, initialize and manage weight update groups, and update model + weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`. + + Args: + host (`str`, *optional*, defaults to `"0.0.0.0"`): + IP address of the vLLM server. + server_port (`int`, *optional*, defaults to `8000`): + Port number of the vLLM server. + group_port (`int`, *optional*, defaults to `51216`): + Port number for the weight update group. + connection_timeout (`float`, *optional*, defaults to `0.0`): + Total timeout duration in seconds to wait for the server to be up. If the server is not up after the + timeout, a `ConnectionError` is raised. + """ + + def __init__(self, + host: str = '0.0.0.0', + server_port: int = 8000, + group_port: int = 51216, + connection_timeout: float = 0.0): + if not is_vllm_available(): + raise ImportError('vLLM is not installed. Please install it with `pip install vllm`.') + + self.session = requests.Session() + self.host = host + self.server_port = server_port + self.group_port = group_port + self.check_server(connection_timeout) # check server and fail after timeout + + def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0): + """ + Check server availability with retries on failure, within a total timeout duration. If the server is not up + after the total timeout duration, raise a `ConnectionError`. + + Args: + retry_interval (`float`, *optional*, defaults to `2.0`): + Interval in seconds between retries. + total_timeout (`float`, *optional*, defaults to `0.0`): + Total timeout duration in seconds. 
+ """ + url = f'http://{self.host}:{self.server_port}/health/' + start_time = time.time() # Record the start time + + while True: + try: + response = requests.get(url) + except requests.exceptions.RequestException as exc: + # Check if the total timeout duration has passed + elapsed_time = time.time() - start_time + if elapsed_time >= total_timeout: + raise ConnectionError( + f"The vLLM server can't be reached at {self.host}:{self.server_port} after {total_timeout} " + 'seconds. Make sure the server is running by running `swift deploy`.') from exc + else: + if response.status_code == 200: + logger.info('Server is up!') + return None + + # Retry logic: wait before trying again + logger.info(f'Server is not up yet. Retrying in {retry_interval} seconds...') + time.sleep(retry_interval) + + def infer( + self, + infer_requests: List[InferRequest], + request_config: Optional[RequestConfig] = None, + metrics: Optional[List[Metric]] = None, + *, + template: Optional[Template] = None, + use_tqdm: Optional[bool] = None, + adapter_request: Optional[AdapterRequest] = None, + ): + url = f'http://{self.host}:{self.server_port}/infer/' + response = self.session.post( + url, + json={ + 'infer_requests': infer_requests, + 'request_config': request_config, + 'metrics': metrics, + 'template': template, + 'use_tqdm': use_tqdm, + 'adapter_request': adapter_request, + }, + ) + if response.status_code == 200: + return [from_dict(data_class=ChatCompletionResponse, data=resp) for resp in response.json()] + else: + raise Exception(f'Request failed: {response.status_code}, {response.text}') + + def init_communicator(self): + """ + Initializes the weight update group in a distributed setup for model synchronization. + """ + # Get the tensor parallel size from the server + url = f'http://{self.host}:{self.server_port}/get_world_size/' + response = requests.get(url) + if response.status_code == 200: + vllm_world_size = response.json()['world_size'] + else: + raise Exception(f'Request failed: {response.status_code}, {response.text}') + + world_size = vllm_world_size + 1 # add the client to the world + self.rank = vllm_world_size # the client's rank is the last process + + # Initialize weight update group + url = f'http://{self.host}:{self.server_port}/init_communicator/' + # In the server side, the host is set to 0.0.0.0 + response = self.session.post(url, json={'host': '0.0.0.0', 'port': self.group_port, 'world_size': world_size}) + if response.status_code != 200: + raise Exception(f'Request failed: {response.status_code}, {response.text}') + + # Brief delay to allow server initialization. While not strictly required (client socket will retry on + # connection failure), this prevents log warnings like: + # [W416 23:24:57.460001114 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3 + time.sleep(0.1) + + # Set up the communication group for weight broadcasting + pg = StatelessProcessGroup.create(host=self.host, port=self.group_port, rank=self.rank, world_size=world_size) + self.pynccl_comm = PyNcclCommunicator(pg, device=0) + + # When the client object is deleted, close the weight update group + atexit.register(self.close_communicator) + + def update_named_param(self, name: str, weights: torch.Tensor): + """ + Updates a specific named parameter in the model and broadcasts it to other processes. + + Args: + name (`str`): + Name of the layer whose weights are being updated. + weights (`torch.Tensor`): + Tensor containing the updated weights. 
+ """ + dtype, shape = str(weights.dtype), tuple(weights.shape) + url = f'http://{self.host}:{self.server_port}/update_named_param/' + response = self.session.post(url, json={'name': name, 'dtype': dtype, 'shape': shape}) + if response.status_code != 200: + raise Exception(f'Request failed: {response.status_code}, {response.text}') + + # Broadcast the weights to the other processes + self.pynccl_comm.broadcast(weights, src=self.rank) + self.pynccl_comm.group.barrier() + + def update_model_params(self, model: nn.Module): + """ + Updates all parameters of the given model by calling `update_named_param` for each parameter in the model. + + Args: + model (`nn.Module`): + Model whose parameters (weights/biases) are to be updated. + """ + for name, param in model.named_parameters(): + # Update each parameter individually + self.update_named_param(name, param.data) + + def reset_prefix_cache(self): + """ + Resets the prefix cache for the model. + """ + url = f'http://{self.host}:{self.server_port}/reset_prefix_cache/' + response = self.session.post(url) + if response.status_code != 200: + raise Exception(f'Request failed: {response.status_code}, {response.text}') + + def close_communicator(self): + """ + Closes the weight update group and cleans up the communication group. + """ + url = f'http://{self.host}:{self.server_port}/close_communicator/' + + try: + response = self.session.post(url) + except ConnectionError: + # The server might be already down, so we don't need to close the communicator + pass + else: + if response.status_code != 200: + raise Exception(f'Request failed: {response.status_code}, {response.text}') diff --git a/swift/trainers/sequence_parallel/__init__.py b/swift/trainers/sequence_parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0602f84075025d705b8910801b030f2591e77804 --- /dev/null +++ b/swift/trainers/sequence_parallel/__init__.py @@ -0,0 +1,8 @@ +import os + +if os.environ.get('SEQUENCE_PARALLEL_IMPL', 'ulysses') == 'xtuner': + from .xtuner import XTuner + sequence_parallel = XTuner() +else: + from .ulysses import Ulysses + sequence_parallel = Ulysses() diff --git a/swift/trainers/sequence_parallel/base.py b/swift/trainers/sequence_parallel/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4c5d3b055c84181779eb5f8a6736698e7383c09f --- /dev/null +++ b/swift/trainers/sequence_parallel/base.py @@ -0,0 +1,45 @@ +import abc +from abc import abstractmethod + + +class SequenceParallel(abc.ABC): + + @abstractmethod + def init_sequence_parallel(self, size): + pass + + @abstractmethod + def prepare_model(self, model, tokenizer, split_in_forward): + pass + + @abstractmethod + def pad_and_split_inputs(self, + tokenizer, + input_ids, + input_embeds, + labels, + position_ids, + attention_mask, + loss_scale, + embed_tokens=None): + pass + + @abstractmethod + def reduce_outputs(self, loss, labels): + pass + + @property + def sp_group(self): + return None + + @abstractmethod + def world_size(self): + pass + + @abstractmethod + def prepare_trainer(self, trainer): + pass + + @abstractmethod + def get_dataloader(self, trainer, dataset, batch_size): + pass diff --git a/swift/trainers/sequence_parallel/ulysses.py b/swift/trainers/sequence_parallel/ulysses.py new file mode 100644 index 0000000000000000000000000000000000000000..d9c415c15e5d9a3009d3b4191f301bf7552e34b0 --- /dev/null +++ b/swift/trainers/sequence_parallel/ulysses.py @@ -0,0 +1,594 @@ +import math +from functools import partial +from types import MethodType +from 
typing import Any, Dict, Iterator, List, Optional, Tuple + +import datasets +import numpy as np +import torch +import torch.distributed as dist +from peft import PeftModel +from torch.distributed.device_mesh import init_device_mesh +from torch.nn import CrossEntropyLoss +from torch.utils.data import DataLoader, Sampler +from transformers.trainer_utils import seed_worker + +from swift.llm import DataLoaderDispatcher, get_model_arch +from swift.tuners import SwiftModel +from swift.utils import get_current_device, get_device, get_dist_setting +from .base import SequenceParallel + + +class GatherLoss(torch.autograd.Function): + """Gather loss from sequence group""" + + @staticmethod + def forward(ctx, loss, labels, process_group, gather_idx=None): + """ + Args: + loss: loss tensor after splitting + labels: labels tensor after splitting + process_group: the sequence parallel group + gather_idx: gather the tensors on this dim + """ + ctx.process_group = process_group + shape0 = labels.shape[0] + ctx.scatter_shape = labels.shape[gather_idx or 0] + ctx.gather_idx = gather_idx or 0 + world_size = dist.get_world_size(group=process_group) # the sp world size + output = torch.empty((shape0 * world_size, *loss.shape[1:]), dtype=loss.dtype, device=loss.device) + # gather all from sp group + dist.all_gather_into_tensor(output, loss, group=process_group) + if gather_idx is not None: + output = torch.cat(output.split(shape0, dim=0), dim=gather_idx) + labels_output = torch.empty((shape0 * world_size, *labels.shape[1:]), dtype=labels.dtype, device=labels.device) + dist.all_gather_into_tensor(labels_output, labels, group=process_group) + if gather_idx is not None: + labels_output = torch.cat(labels_output.split(shape0, dim=0), dim=gather_idx) + return output, labels_output + + @staticmethod + def backward(ctx, *grad_output): + _grad = grad_output[0] * dist.get_world_size(group=ctx.process_group) + return _grad.split( + ctx.scatter_shape, dim=ctx.gather_idx)[dist.get_rank(ctx.process_group)].contiguous(), None, None, None + + +# For nll loss +def loss_scale_sp_func(outputs, labels, loss_scale=None, num_items_in_batch=None, process_group=None) -> torch.Tensor: + if hasattr(outputs, 'logits'): + logits = outputs.logits + else: + logits = outputs + device = logits.device + logits = logits.view(-1, logits.shape[-1]) + labels = labels.flatten().to(device) + # Flatten the tokens + loss_fct = CrossEntropyLoss(reduction='none') + # flatten loss + loss = loss_fct(logits, labels) + + if loss_scale is not None: + loss_scale = loss_scale.flatten().to(loss.device) + loss = (loss_scale * loss) + loss, labels = GatherLoss.apply(loss, labels, process_group) + loss = loss[labels != -100].sum() + if num_items_in_batch is None: + loss = loss / (labels != -100).sum() + else: + loss = loss / num_items_in_batch + return loss + + +# For DPO +def get_batch_logps(logits: torch.FloatTensor, + labels: torch.LongTensor, + label_pad_token_id: int = -100, + is_encoder_decoder: bool = False, + process_group=None) -> Tuple[torch.FloatTensor, torch.LongTensor]: + labels = labels.clone() # No need to shift, pad and split has shifted the inputs. 
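+    # Note: under Ulysses each rank only holds its local sequence shard, so the
+    # per-token logps computed below are gathered along the sequence dimension
+    # (GatherLoss with gather_idx=1) before summing, which reproduces the
+    # full-sequence chosen/rejected log-probabilities that DPO expects.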
+ loss_mask = labels != label_pad_token_id + labels[labels == label_pad_token_id] = 0 + labels = labels.to(logits.device) + loss_mask = loss_mask.to(logits.device) + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + total_per_token_logps, total_loss_mask = GatherLoss.apply(per_token_logps, loss_mask, process_group, 1) + return (total_per_token_logps * total_loss_mask).sum(-1), total_loss_mask.sum(-1) + + +class UlyssesSampler(Sampler): + + # Code borrowed from mmengine + def __init__(self, ulysses, dataset, shuffle: bool = True, seed=None, round_up: bool = True) -> None: + self.ulysses = ulysses + rank = dist.get_rank(ulysses.device_mesh['data'].get_group()) + world_size = ulysses.device_mesh['data'].size() + self.rank = rank + self.world_size = world_size + + self.dataset = dataset + self.shuffle = shuffle + assert seed is not None + self.seed = seed + self.epoch = 0 + self.round_up = round_up + + if self.round_up: + self.num_samples = math.ceil(len(self.dataset) / world_size) + self.total_size = self.num_samples * self.world_size + else: + self.num_samples = math.ceil((len(self.dataset) - rank) / world_size) + self.total_size = len(self.dataset) + + def __iter__(self) -> Iterator[int]: + if self.shuffle: + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = torch.arange(len(self.dataset)).tolist() + + if self.round_up: + indices = (indices * int(self.total_size / len(indices) + 1))[:self.total_size] + + indices = indices[self.rank:self.total_size:self.world_size] + + return iter(indices) + + def __len__(self) -> int: + return self.num_samples + + def set_epoch(self, epoch: int) -> None: + self.epoch = epoch + + +class UlyssesDispatcher(DataLoaderDispatcher): + + def __init__(self, base_dataloader, ulysses): + super().__init__(base_dataloader) + self.ulysses = ulysses + + def __iter__(self): + base_iter = iter(self.base_dataloader) + while True: + data = None + try: + for i in range(self.ulysses.dp_world_size): + data = next(base_iter) + if i == self.ulysses.dp_rank: + break + except StopIteration: + pass + if data is None: + break + yield data + + +# Code borrowed from deepspeed, here is why: +# 1. Reduce the dependency +# 2. The original code is complex +def _generate_layout_params(scatter_idx, seq_world_size, input): + if scatter_idx < 2: + bs, global_seq_len, num_local_head, head_dim = input.shape + pre_all2all_inp_shape = [bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, head_dim] + pre_all2all_permute_idx = (1, 0, 2, 3, 4) + + post_all2all_permute_idx = (1, 2, 0, 3, 4) + post_all2all_res_shape = [bs, global_seq_len // seq_world_size, seq_world_size * num_local_head, head_dim] + else: + bs, local_seq_len, num_total_head, head_dim = input.shape + assert num_total_head % seq_world_size == 0, (f'Number of heads ({num_total_head}) must be divisible ' + f'by the sequence parallel size ({seq_world_size})!') + pre_all2all_inp_shape = [bs, local_seq_len, seq_world_size, num_total_head // seq_world_size, head_dim] + pre_all2all_permute_idx = (2, 0, 1, 3, 4) + + post_all2all_permute_idx = (1, 0, 2, 3, 4) + post_all2all_res_shape = [bs, seq_world_size * local_seq_len, num_total_head // seq_world_size, head_dim] + + return pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape + + +def post_all2all(permute_idx, res_shape): + """ + Post-processing function for `all2all` communication. 
+ """ + + def post_func(input): + if permute_idx is not None: + input = input.permute(permute_idx).contiguous() + output = input.reshape(res_shape).contiguous() + + return output + + return post_func + + +def pre_all2all_fun(permute_idx, inp_shape, input): + """ + Pre-processing function for `all2all` communication. + """ + input_t = input.reshape(inp_shape).contiguous() + if permute_idx is not None: + input_t = input_t.permute(permute_idx).contiguous() + return input_t + + +def single_all_to_all(input, scatter_idx, gather_idx, group, **kwargs): + seq_world_size = dist.get_world_size(group) + num_heads = input.shape[2] + if num_heads % seq_world_size != 0 and not scatter_idx < 2: + raise NotImplementedError + pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape = ( + _generate_layout_params(scatter_idx, seq_world_size, input)) + + input_t = pre_all2all_fun(pre_all2all_permute_idx, pre_all2all_inp_shape, input) + + post_all2all_fun = post_all2all(post_all2all_permute_idx, post_all2all_res_shape) + output = torch.empty_like(input_t) + dist.all_to_all_single(output, input_t, group=group) + + res = post_all2all_fun(output) + return res + + +class _SeqAllToAll(torch.autograd.Function): + + @staticmethod + def forward( + ctx: Any, + group: dist.ProcessGroup, + input: torch.Tensor, + scatter_idx: int, + gather_idx: int, + ) -> torch.Tensor: + ctx.group = group + ctx.scatter_idx = scatter_idx + ctx.gather_idx = gather_idx + res = single_all_to_all(input, scatter_idx, gather_idx, group) + return res + + @staticmethod + def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, None, None]: + return None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None + + +class DistributedAttention(torch.nn.Module): + + def __init__( + self, + local_attention, + sequence_process_group: dist.ProcessGroup, + scatter_idx: int = 2, + gather_idx: int = 1, + ) -> None: + super(DistributedAttention, self).__init__() + self.local_attn = local_attention + self.spg = sequence_process_group + self.scatter_idx = scatter_idx + self.gather_idx = gather_idx + + def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor, + *args: Any, **kwargs) -> torch.Tensor: + query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx) + key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx) + value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx) + position_ids = kwargs.pop('position_ids', None) + if position_ids is not None: + shape0 = position_ids.shape[0] + position_ids_output = torch.empty((shape0 * dist.get_world_size(self.spg), position_ids.shape[1]), + dtype=position_ids.dtype, + device=position_ids.device) + dist.all_gather_into_tensor(position_ids_output, position_ids, group=self.spg) + position_ids = torch.cat(position_ids_output.split(shape0, dim=0), dim=1) + context_layer = self.local_attn( + query_layer, key_layer, value_layer, attention_mask, *args, position_ids=position_ids, **kwargs) + output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx) + return output + + +class Ulysses(SequenceParallel): + + def __init__(self): + self.split_in_forward = None + self.dp_world_size = None + self.sp_world_size = None + self.model_dtype = None + self.causal_mask_func = None + self.device_mesh = None + self._inited = False + + def init_sequence_parallel(self, size): + if 
self._inited: + return + self._inited = True + self.sp_world_size = size + rank, local_rank, world_size, local_world_size = get_dist_setting() + self.dp_world_size = world_size // size + self.device_mesh = init_device_mesh( + get_device().split(':')[0], mesh_shape=(world_size // size, size), mesh_dim_names=['data', 'sequence']) + + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + ALL_ATTENTION_FUNCTIONS['flash_attention_2_origin'] = ALL_ATTENTION_FUNCTIONS['flash_attention_2'] + ALL_ATTENTION_FUNCTIONS['sdpa_origin'] = ALL_ATTENTION_FUNCTIONS['sdpa'] + + def local_flash_attn(module: torch.nn.Module, query_states, key_states, value_states, attention_mask, *args, + dist_attn, **kwargs): + if dist_attn.local_attn is None: + + def _attention(query, key, value, *args, **kwargs): + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + return ALL_ATTENTION_FUNCTIONS['flash_attention_2_origin'](module, query, key, value, *args, + **kwargs)[0] + + dist_attn.local_attn = _attention + + return dist_attn( + query_states.transpose(1, 2), key_states.transpose(1, 2), value_states.transpose(1, 2), attention_mask, + *args, **kwargs), None + + def local_sdpa_attn(module: torch.nn.Module, query_states, key_states, value_states, attention_mask, *args, + dist_attn, **kwargs): + if dist_attn.local_attn is None: + + def _attention(query, key, value, *args, **kwargs): + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + return ALL_ATTENTION_FUNCTIONS['sdpa_origin'](module, query, key, value, *args, **kwargs)[0] + + dist_attn.local_attn = _attention + return dist_attn( + query_states.transpose(1, 2), key_states.transpose(1, 2), value_states.transpose(1, 2), attention_mask, + *args, **kwargs), None + + ALL_ATTENTION_FUNCTIONS['flash_attention_2'] = partial( + local_flash_attn, dist_attn=DistributedAttention(None, self.sp_group)) + ALL_ATTENTION_FUNCTIONS['sdpa'] = partial(local_sdpa_attn, dist_attn=DistributedAttention(None, self.sp_group)) + + from transformers.modeling_flash_attention_utils import is_flash_attn_available + if is_flash_attn_available(): + # TODO this works for multi-modal models like qwen2.5-vl + # SDPA is not supported, because we need to copy the code to our project, which will bring + # more works for maintaining. 
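+            # The module-level `_flash_attention_forward` is additionally wrapped in a
+            # DistributedAttention below, so models that call it directly (instead of going
+            # through ALL_ATTENTION_FUNCTIONS) also get the sequence-parallel all-to-all;
+            # `q_len` is scaled by sp_world_size because after the all-to-all each rank
+            # attends over the full (gathered) sequence length.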
+ from transformers import modeling_flash_attention_utils + from transformers.modeling_flash_attention_utils import _flash_attention_forward + _distributed_flash_attention = DistributedAttention(_flash_attention_forward, self.sp_group) + + def flash_attention_forward(query_states: torch.Tensor, key_states: torch.Tensor, + value_states: torch.Tensor, attention_mask: Optional[torch.Tensor], q_len, + *args, **kwargs): + return _distributed_flash_attention(query_states, key_states, value_states, attention_mask, + q_len * self.sp_world_size, *args, **kwargs) + + modeling_flash_attention_utils._flash_attention_forward = flash_attention_forward + + def prepare_model(self, model, tokenizer, split_in_forward): + self.split_in_forward = split_in_forward + + def forward(_self, **kwargs): + # Split embedding here for multi-modal + inputs_embeds = kwargs['inputs_embeds'] + position_ids = kwargs['position_ids'] + attention_mask = kwargs['attention_mask'] + _, inputs_embeds, _, position_ids, attention_mask, _ = self.pad_and_split_inputs( + tokenizer, + None, + inputs_embeds, + None, + position_ids, + attention_mask, + None, + embed_tokens=_self.embed_tokens) + kwargs['inputs_embeds'] = inputs_embeds + kwargs['position_ids'] = position_ids + kwargs['attention_mask'] = attention_mask + return _self.forward_origin(**kwargs) + + if isinstance(model, (SwiftModel, PeftModel)): + model = model.model + model_meta = model.model_meta + llm_prefix = getattr(get_model_arch(model_meta.model_arch), 'language_model', None) + if llm_prefix: + llm_model = getattr(model, llm_prefix[0]) + else: + llm_model = model + + if 'CausalLM' not in llm_model.__class__.__name__: + llm_model = model + + base_model = llm_model.model + self.causal_mask_func = base_model._update_causal_mask + if self.split_in_forward: + # for multi modal models + base_model.forward_origin = base_model.forward + base_model.forward = MethodType(forward, base_model) + + self.model_dtype = next(model.parameters()).dtype + + def _pad_sp(self, tensor, padding_value, dim=-1): + # code borrowed from xtuner + length = tensor.shape[dim] + if length % self.sp_world_size == 0: + return tensor + + pad_num = self.sp_world_size - (length % self.sp_world_size) + if not isinstance(padding_value, torch.Tensor): + # ids + pad_shape = ((*tensor.shape[:dim], pad_num, *tensor.shape[dim + 1:]) if dim != -1 else + (*tensor.shape[:dim], pad_num)) + pad = torch.full(pad_shape, padding_value, dtype=tensor.dtype, device=tensor.device) + tensor = torch.cat([tensor, pad], dim=dim) + else: + # For embeddings + tensor = torch.cat([tensor, padding_value.unsqueeze(0).repeat(tensor.shape[0], pad_num, 1)], dim=dim) + return tensor + + def world_size(self): + return self.sp_world_size + + def _split_sp(self, input, dim: int, sp_group: dist.ProcessGroup): + # code borrowed from xtuner + if self.sp_world_size == 1: + return input + + rank = dist.get_rank(sp_group) + dim_size = input.size(dim) + assert dim_size % self.sp_world_size == 0, (f'The dimension to split ({dim_size}) is not a multiple of ' + f'world size ({self.sp_world_size}), cannot split tensor evenly') + + tensor_list = torch.split(input, dim_size // self.sp_world_size, dim=dim) + output = tensor_list[rank].contiguous() + + return output + + def pad_and_split_inputs(self, + tokenizer, + input_ids, + input_embeds, + labels, + position_ids, + attention_mask, + loss_scale, + embed_tokens=None): + sp_group = self.sp_group + split_inputs = False + if (input_ids is not None and not self.split_in_forward) or input_embeds is not None: + # 
Whether split the model inputs + # cannot split input_ids for multi-modal models + split_inputs = True + if input_ids is not None and split_inputs: + input_ids = self._pad_sp(input_ids, padding_value=tokenizer.pad_token_id, dim=-1) + if input_embeds is not None: + pad_emb = embed_tokens(torch.tensor(tokenizer.pad_token_id).to(embed_tokens.weight.device)).unsqueeze(0) + input_embeds = self._pad_sp(input_embeds, padding_value=pad_emb, dim=1) + if position_ids is not None and split_inputs: + position_ids = self._pad_sp(position_ids, padding_value=0, dim=-1) + if split_inputs: + inputs = input_ids if input_ids is not None else input_embeds + attn_shape = inputs.shape[1] # The sequence length + if attention_mask is None: + attention_mask = torch.ones_like(position_ids) + attention_mask = self._pad_sp(attention_mask, padding_value=0, dim=-1) + cache_position = torch.arange(0, attn_shape, device=inputs.device) + # pad attention mask to 4d to avoid calculation errors + attention_mask = self.causal_mask_func(attention_mask, inputs.to(self.model_dtype), cache_position, None, + None) + if input_ids is not None and split_inputs: + input_ids = self._split_sp(input_ids, dim=1, sp_group=sp_group) + if input_embeds is not None: + input_embeds = self._split_sp(input_embeds, dim=1, sp_group=sp_group) + if position_ids is not None and split_inputs: + position_ids = self._split_sp(position_ids, dim=-1, sp_group=sp_group) + if labels is not None: + labels = self._pad_sp(labels, padding_value=-100, dim=-1) + labels[:, 0] = -100 # make the last invalid, so we do not need to cut the loss of last token + labels = torch.roll(labels, shifts=-1, dims=1) + labels = self._split_sp(labels, dim=1, sp_group=sp_group) + + if loss_scale is not None: + loss_scale = self._pad_sp(loss_scale, padding_value=0., dim=-1) + loss_scale = torch.roll(loss_scale, shifts=-1, dims=-1) + loss_scale = self._split_sp(loss_scale, dim=-1, sp_group=sp_group) + + return input_ids, input_embeds, labels, position_ids, attention_mask, loss_scale + + def reduce_outputs(self, loss, labels): + return loss + + @property + def sp_rank(self): + return dist.get_rank(self.device_mesh['sequence'].get_group()) + + @property + def dp_rank(self): + return dist.get_rank(self.device_mesh['data'].get_group()) + + @property + def sp_group(self): + return self.device_mesh['sequence'].get_group() + + @property + def dp_group(self): + return self.device_mesh['data'].get_group() + + def get_dataloader(self, trainer, dataset, batch_size): + data_collator = trainer.data_collator + if isinstance(dataset, datasets.Dataset): + dataset = trainer._remove_unused_columns(dataset, description='training') + else: + data_collator = trainer._get_collator_with_removed_columns(data_collator, description='training') + if hasattr(dataset, '__len__'): + sampler = UlyssesSampler(self, dataset, seed=42) + dataloader_params = { + 'batch_size': batch_size, + 'collate_fn': data_collator, + 'num_workers': trainer.args.dataloader_num_workers, + 'pin_memory': trainer.args.dataloader_pin_memory, + 'persistent_workers': trainer.args.dataloader_persistent_workers, + } + + if not isinstance(dataset, torch.utils.data.IterableDataset): + dataloader_params['sampler'] = sampler + dataloader_params['drop_last'] = trainer.args.dataloader_drop_last + dataloader_params['worker_init_fn'] = seed_worker + + return DataLoader(dataset, **dataloader_params) + else: + dataloader_params = { + 'collate_fn': data_collator, + 'num_workers': trainer.args.dataloader_num_workers, + 'pin_memory': 
trainer.args.dataloader_pin_memory, + 'persistent_workers': trainer.args.dataloader_persistent_workers, + 'prefetch_factor': trainer.args.dataloader_prefetch_factor + } + if dist.is_initialized() and dataloader_params['prefetch_factor']: + dataloader_params['prefetch_factor'] = dataloader_params['prefetch_factor'] * dist.get_world_size() + dataloader = DataLoader(dataset, batch_size=batch_size, **dataloader_params) + dataloader = UlyssesDispatcher(dataloader, self) + return dataloader + + def prepare_trainer(self, trainer): + if trainer.train_dataset is None: + raise ValueError('Trainer: training requires a train_dataset.') + + trainer.compute_loss_func = partial(loss_scale_sp_func, process_group=self.sp_group) + if hasattr(trainer, 'get_batch_logps'): + trainer.get_batch_logps = partial(get_batch_logps, process_group=self.sp_group) + if hasattr(trainer, 'get_nll_loss'): + + def rlhf_loss_scale_sp_func(_, *args, **kwargs): + return loss_scale_sp_func(*args, process_group=self.sp_group, **kwargs) + + trainer.get_nll_loss = MethodType(rlhf_loss_scale_sp_func, trainer) + + from swift.plugin import metric + from swift.trainers import mixin + compute_acc_origin = metric.compute_acc + + def compute_acc(preds, labels, *args, **kwargs) -> Dict[str, List[float]]: + + # Gather preds and labels across the sp group + if isinstance(preds, np.ndarray): + preds = torch.from_numpy(preds).to(get_current_device()) + if isinstance(labels, np.ndarray): + labels = torch.from_numpy(labels).to(get_current_device()) + shape0 = preds.shape[0] + preds_output = torch.empty((shape0 * self.sp_world_size, preds.shape[1]), + dtype=preds.dtype, + device=preds.device) + dist.all_gather_into_tensor(preds_output, preds, group=self.sp_group) + preds_output = torch.cat(preds_output.split(shape0, dim=0), dim=1) + shape0 = labels.shape[0] + labels_output = torch.empty((shape0 * self.sp_world_size, labels.shape[1]), + dtype=labels.dtype, + device=labels.device) + dist.all_gather_into_tensor(labels_output, labels, group=self.sp_group) + labels_output = torch.cat(labels_output.split(shape0, dim=0), dim=1) + # roll back to fit compute_acc + labels_output = torch.roll(labels_output, shifts=1, dims=1) + return compute_acc_origin(preds_output, labels_output, *args, **kwargs) + + metric.compute_acc = compute_acc + mixin.compute_acc = compute_acc diff --git a/swift/trainers/sequence_parallel/xtuner.py b/swift/trainers/sequence_parallel/xtuner.py new file mode 100644 index 0000000000000000000000000000000000000000..c3e43b6bb65aeeee18b6ba40fb42e44db9c4394d --- /dev/null +++ b/swift/trainers/sequence_parallel/xtuner.py @@ -0,0 +1,127 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any + +import datasets +import torch +import torch.distributed as dist +from datasets import Dataset +from torch.utils.data import DataLoader +from transformers.trainer_utils import seed_worker + +from .base import SequenceParallel + + +class XTuner(SequenceParallel): + + @staticmethod + def assert_xtuner_runtime_condition(): + from swift.utils import is_xtuner_available + assert is_xtuner_available(), \ + ('Please install XTuner first to pack dataset to `max_length`.' + '`pip install -U \'xtuner[deepspeed]\'`') + assert dist.is_initialized(), 'pack_to_max_length is only available with distributed training.' 
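+    # pack_dataset_xtuner (below): rank 0 packs the raw samples to `args.max_length`
+    # with xtuner's `pack_dataset`, then the packed dataset object is broadcast to the
+    # remaining ranks via `dist.broadcast_object_list`, so every rank trains on the same
+    # packed data without repeating the packing step.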
+ + def pack_dataset_xtuner(self, dataset: Dataset, args: Any) -> Any: + self.assert_xtuner_runtime_condition() + if dist.get_rank() == 0: + ds = [i[0] for i in dataset.data] + train_dataset = Dataset.from_list(ds) + from xtuner.dataset.huggingface import pack_dataset + train_dataset = pack_dataset( + train_dataset, + max_length=args.max_length, + use_varlen_attn=False, + shuffle_before_pack=True, + map_num_proc=16) + objects = [train_dataset] + train_dataset.save_to_disk('alpaca_pack') + else: + objects = [None] + dist.broadcast_object_list(objects, src=0) + train_dataset = objects[0] + return train_dataset + + @property + def sp_group(self): + from xtuner.parallel.sequence import get_sequence_parallel_group + return get_sequence_parallel_group() + + def init_sequence_parallel(self, size): + self.assert_xtuner_runtime_condition() + from xtuner.parallel.sequence import init_sequence_parallel + init_sequence_parallel(size) + + def prepare_model(self, model, tokenizer, split_in_forward): + self.assert_xtuner_runtime_condition() + from xtuner.model.modules.dispatch import dispatch_modules + dispatch_modules(model) + + def pad_and_split_inputs(self, + tokenizer, + input_ids, + input_embeds, + labels, + position_ids, + attention_mask, + loss_scale, + embed_tokens=None): + self.assert_xtuner_runtime_condition() + from xtuner.parallel.sequence import (pad_for_sequence_parallel, split_for_sequence_parallel, + get_sequence_parallel_group) + input_ids = pad_for_sequence_parallel(input_ids, padding_value=tokenizer.pad_token_id, dim=-1) + labels = pad_for_sequence_parallel(labels, padding_value=-100, dim=-1) + position_ids = pad_for_sequence_parallel(position_ids, padding_value=0, dim=-1) + if attention_mask is not None: + attention_mask = pad_for_sequence_parallel(attention_mask, padding_value=0, dim=-1) + + sp_group = get_sequence_parallel_group() + input_ids = split_for_sequence_parallel(input_ids, dim=1, sp_group=sp_group) + labels = split_for_sequence_parallel(labels, dim=1, sp_group=sp_group) + position_ids = split_for_sequence_parallel(position_ids, dim=1, sp_group=sp_group) + if attention_mask is not None: + attention_mask = split_for_sequence_parallel(attention_mask, dim=-1, sp_group=sp_group) + if loss_scale is not None: + loss_scale = pad_for_sequence_parallel(loss_scale, padding_value=0., dim=-1) + loss_scale = split_for_sequence_parallel(loss_scale, dim=1, sp_group=sp_group) + + return input_ids, None, labels, position_ids, attention_mask, loss_scale + + def reduce_outputs(self, loss, labels): + from xtuner.parallel.sequence import (reduce_sequence_parallel_loss, get_sequence_parallel_group) + # reduce loss for logging correctly + num_tokens = (labels != -100).sum() + return reduce_sequence_parallel_loss(loss, num_tokens, get_sequence_parallel_group()) + + def world_size(self): + self.assert_xtuner_runtime_condition() + from xtuner.parallel.sequence import get_sequence_parallel_world_size + return get_sequence_parallel_world_size() + + def prepare_trainer(self, trainer): + pass + + def get_dataloader(self, trainer, dataset, batch_size): + # modified from HFTrainer.get_train_dataloader + # RandomSampler -> SequenceParallelSampler + self.assert_xtuner_runtime_condition() + data_collator = trainer.data_collator + if isinstance(dataset, datasets.Dataset): + dataset = trainer._remove_unused_columns(dataset, description='training') + else: + data_collator = trainer._get_collator_with_removed_columns(data_collator, description='training') + + dataloader_params = { + 'batch_size': batch_size, + 
'collate_fn': data_collator, + 'num_workers': trainer.args.dataloader_num_workers, + 'pin_memory': trainer.args.dataloader_pin_memory, + 'persistent_workers': trainer.args.dataloader_persistent_workers, + } + + if not isinstance(dataset, torch.utils.data.IterableDataset): + from xtuner.parallel import SequenceParallelSampler + dataloader_params['sampler'] = SequenceParallelSampler(dataset, seed=1024) + dataloader_params['drop_last'] = trainer.args.dataloader_drop_last + dataloader_params['worker_init_fn'] = seed_worker + + return DataLoader(dataset, **dataloader_params) diff --git a/swift/trainers/torchacc_mixin.py b/swift/trainers/torchacc_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..7cb373794be9040aa4d0bd56b96d9a1fccf14812 --- /dev/null +++ b/swift/trainers/torchacc_mixin.py @@ -0,0 +1,156 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +from typing import Optional + +from transformers import PreTrainedModel, is_datasets_available + +from swift.utils import use_torchacc +from swift.utils.torchacc_utils import (patch_clip_grad_norm, save_ta_ddp_checkpoint, save_ta_fsdp_checkpoint, + ta_eval_dataloader, ta_load_optimizer_and_scheduler, + ta_save_optimizer_and_scheduler, ta_test_dataloader, ta_train_dataloader, + ta_trim_graph) + + +class TorchAccMixin: + + def __init__(self, *args, **kwargs): + if use_torchacc(): + patch_clip_grad_norm(self.accelerator) + super().__init__(*args, **kwargs) + + def get_train_dataloader(self): + if not use_torchacc(): + return super().get_train_dataloader() + + if is_datasets_available(): + import datasets + + if self.train_dataset is None: + raise ValueError('Trainer: training requires a train_dataset.') + + train_dataset = self.train_dataset + data_collator = self.data_collator + + if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + train_dataset = self._remove_unused_columns(train_dataset, description='training') + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description='training') + + return ta_train_dataloader(train_dataset, data_collator, self._get_train_sampler(), self.args, + self._train_batch_size) + + def get_eval_dataloader(self, eval_dataset=None): + + if not use_torchacc(): + return super().get_eval_dataloader(eval_dataset) + + if is_datasets_available(): + import datasets + + if eval_dataset is None and self.eval_dataset is None: + raise ValueError('Trainer: evaluation requires an eval_dataset.') + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + data_collator = self.data_collator + + if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset): + eval_dataset = self._remove_unused_columns(eval_dataset, description='evaluation') + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description='evaluation') + + return ta_eval_dataloader(eval_dataset, data_collator, self._get_eval_sampler(eval_dataset), self.args) + + def get_test_dataloader(self, test_dataset): + + if not use_torchacc(): + return super().get_test_dataloader(test_dataset) + + if is_datasets_available(): + import datasets + + data_collator = self.data_collator + + if is_datasets_available() and isinstance(test_dataset, datasets.Dataset): + test_dataset = self._remove_unused_columns(test_dataset, description='test') + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description='test') + + return ta_test_dataloader(test_dataset, data_collator, 
self._get_eval_sampler(test_dataset), self.args) + + def _save_tpu(self, output_dir: Optional[str] = None): + + if not use_torchacc(): + return super()._save_tpu(output_dir) + + import torch_xla.core.xla_model as xm + + # Compatible with swift and peft + output_dir = output_dir if output_dir is not None else self.args.output_dir + + if xm.is_master_ordinal(local=False): + os.makedirs(output_dir, exist_ok=True) + # configuration.json + model_dir = getattr(self.model, 'model_dir', None) + if model_dir is not None: + src_path = os.path.join(model_dir, 'configuration.json') + dst_path = os.path.join(output_dir, 'configuration.json') + if os.path.exists(src_path): + shutil.copy(src_path, dst_path) + else: + self._create_configuration_file(self.model, output_dir) + self._save_sft_args(output_dir) + # generation_config + generation_config = getattr(self.args, 'generation_config', None) + if generation_config is not None: + generation_config.save_pretrained(output_dir) + + # model + if self.args.fsdp_num > 1: + save_ta_fsdp_checkpoint(self.model, self.tokenizer, self.args, output_dir) + else: + save_ta_ddp_checkpoint(self.model, self.tokenizer, self.args, output_dir) + + # additional files + if xm.is_master_ordinal(local=False): + if self.args is not None and self.args.sft_type == 'full': + additional_files = getattr(self.args, 'additional_saved_files', + None) or [] + ['preprocessor_config.json'] + if model_dir is not None: + for file in additional_files: + src_path = os.path.join(model_dir, file) + dst_path = os.path.join(output_dir, file) + if os.path.isfile(src_path): + shutil.copy(src_path, dst_path) + elif os.path.isdir(src_path): + shutil.copytree(src_path, dst_path) + + def _load_optimizer_and_scheduler(self, checkpoint): + + if not use_torchacc() or self.args.fsdp_num == 1: + return super()._load_optimizer_and_scheduler(checkpoint) + + self.optimizer, self.lr_scheduler = ta_load_optimizer_and_scheduler(self.optimizer, self.lr_scheduler, + checkpoint, self.args.device) + + def _save_optimizer_and_scheduler(self, output_dir): + if not use_torchacc() or not self.args.fsdp_num == 1: + return super()._save_optimizer_and_scheduler(output_dir) + + return ta_save_optimizer_and_scheduler(self.optimizer, self.lr_scheduler, output_dir) + + def _maybe_log_save_evaluate(self, tr_loss, *args, **kwargs): + if use_torchacc() and self.control.should_log: + ta_trim_graph() + super()._maybe_log_save_evaluate(tr_loss, *args, **kwargs) + + def _load_from_checkpoint(self, resume_from_checkpoint: str, model=None) -> None: + if use_torchacc(): + if model is None: + model = self.model + # Loading checkpoint of TorchAcc has been done in tuner.py when + # sft_type is 'full'. + if self.args.fsdp_num > 1: + model = model._get_underlay_model().module.module + if isinstance(model, PreTrainedModel): + return + return super()._load_from_checkpoint(resume_from_checkpoint, model) diff --git a/swift/trainers/trainer_factory.py b/swift/trainers/trainer_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..87657d45d41d4606535549af69da3a9962865b6f --- /dev/null +++ b/swift/trainers/trainer_factory.py @@ -0,0 +1,64 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
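+# Usage sketch (illustrative only; `args` and `model` below are placeholders, not part
+# of this module). The factory resolves trainer and config classes lazily from the
+# dotted paths in the mappings, keyed by `args.rlhf_type` when present, otherwise
+# `args.task_type`:
+#
+#     trainer_cls = TrainerFactory.get_trainer_cls(args)
+#     training_args = TrainerFactory.get_training_args(args)
+#     trainer = trainer_cls(model=model, args=training_args)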
+import importlib.util +import inspect +from dataclasses import asdict +from typing import Dict + +from swift.utils import get_logger + +logger = get_logger() + + +class TrainerFactory: + TRAINER_MAPPING = { + 'causal_lm': 'swift.trainers.Seq2SeqTrainer', + 'seq_cls': 'swift.trainers.Trainer', + 'embedding': 'swift.trainers.EmbeddingTrainer', + 'dpo': 'swift.trainers.DPOTrainer', + 'orpo': 'swift.trainers.ORPOTrainer', + 'kto': 'swift.trainers.KTOTrainer', + 'cpo': 'swift.trainers.CPOTrainer', + 'rm': 'swift.trainers.RewardTrainer', + 'ppo': 'swift.trainers.PPOTrainer', + 'grpo': 'swift.trainers.GRPOTrainer' + } + + TRAINING_ARGS_MAPPING = { + 'causal_lm': 'swift.trainers.Seq2SeqTrainingArguments', + 'seq_cls': 'swift.trainers.TrainingArguments', + 'embedding': 'swift.trainers.TrainingArguments', + 'dpo': 'swift.trainers.DPOConfig', + 'orpo': 'swift.trainers.ORPOConfig', + 'kto': 'swift.trainers.KTOConfig', + 'cpo': 'swift.trainers.CPOConfig', + 'rm': 'swift.trainers.RewardConfig', + 'ppo': 'swift.trainers.PPOConfig', + 'grpo': 'swift.trainers.GRPOConfig', + } + + @staticmethod + def get_cls(args, mapping: Dict[str, str]): + if hasattr(args, 'rlhf_type'): + train_method = args.rlhf_type + else: + train_method = args.task_type + module_path, class_name = mapping[train_method].rsplit('.', 1) + module = importlib.import_module(module_path) + return getattr(module, class_name) + + @classmethod + def get_trainer_cls(cls, args): + return cls.get_cls(args, cls.TRAINER_MAPPING) + + @classmethod + def get_training_args(cls, args): + training_args_cls = cls.get_cls(args, cls.TRAINING_ARGS_MAPPING) + args_dict = asdict(args) + parameters = inspect.signature(training_args_cls).parameters + + for k in list(args_dict.keys()): + if k not in parameters: + args_dict.pop(k) + + args._prepare_training_args(args_dict) + return training_args_cls(**args_dict) diff --git a/swift/trainers/trainers.py b/swift/trainers/trainers.py new file mode 100644 index 0000000000000000000000000000000000000000..24bd3e42826cab35f8953daecae37c515c766845 --- /dev/null +++ b/swift/trainers/trainers.py @@ -0,0 +1,208 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Part of the implementation is borrowed from huggingface/transformers. 
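+# This module wraps the Hugging Face Trainer/Seq2SeqTrainer with swift-specific
+# behaviour: accuracy tracking in `compute_loss`, a loss-function patch for
+# device_map setups, gradient-accumulation loss scaling for transformers>=4.46,
+# optional sequence-parallel loss reduction, and a `predict_with_generate`
+# evaluation path that writes predictions to `predict.jsonl`.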
+import os +from contextlib import contextmanager, nullcontext +from functools import wraps +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from peft import PeftModel +from torch import nn +from torch.nn.utils.rnn import pad_sequence +from transformers import EvalPrediction +from transformers import Seq2SeqTrainer as HfSeq2SeqTrainer +from transformers import Trainer as HfTrainer +from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES +from transformers.utils import is_peft_available + +from swift.utils import JsonlWriter, Serializer, gc_collect +from .arguments import Seq2SeqTrainingArguments, TrainingArguments +from .mixin import DataLoaderMixin, SwiftMixin + + +class Trainer(SwiftMixin, HfTrainer): + args: TrainingArguments + + @contextmanager + def _patch_loss_function(self): + model = self.model + if isinstance(model, PeftModel): + model = model.model + model_cls = model.__class__ + if not hasattr(model_cls, 'loss_function'): + yield + return + + loss_function = model.loss_function + _old_loss_function = model_cls.loss_function + + @staticmethod + @wraps(loss_function) + def new_loss_function(logits, labels, **kwargs): + labels = labels.to(logits.device) # fix device_map + return loss_function(logits=logits, labels=labels, **kwargs) + + model_cls.loss_function = new_loss_function + try: + yield + finally: + model_cls.loss_function = _old_loss_function + + def train(self, *args, **kwargs): + with self._patch_loss_function(): + return super().train(*args, **kwargs) + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + loss, outputs = super().compute_loss(model, inputs, return_outputs=True) + if inputs.get('labels') is not None: + self._compute_acc(outputs, inputs['labels']) + if num_items_in_batch is not None and self.model_accepts_loss_kwargs: + loss /= self.args.gradient_accumulation_steps + return (loss, outputs) if return_outputs else loss + + +class EmbeddingTrainer(Trainer): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.compute_metrics = self.calculate_metric + self.preprocess_logits_for_metrics = None + self.label_names = ['labels'] + + def calculate_metric(self, eval_prediction: EvalPrediction) -> Dict[str, float]: + from swift.plugin.loss import infonce_loss, calculate_paired_metrics, calculate_infonce_metrics + if self.compute_loss_func is infonce_loss: + return calculate_infonce_metrics(eval_prediction.predictions, eval_prediction.label_ids) + else: + return calculate_paired_metrics(eval_prediction.predictions, eval_prediction.label_ids) + + +class Seq2SeqTrainer(SwiftMixin, DataLoaderMixin, HfSeq2SeqTrainer): + args: Seq2SeqTrainingArguments + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.model_accepts_loss_kwargs = True # fix transformers>=4.46.2 + if self.args.predict_with_generate: + from swift.llm import PtEngine + self.infer_engine = PtEngine.from_model_template( + self.model, self.template, max_batch_size=self.args.per_device_eval_batch_size) + self.jsonl_writer = JsonlWriter(os.path.join(self.args.output_dir, 'predict.jsonl')) + + @staticmethod + def _predict_data_collator(batch): + return {'_data': batch} + + @contextmanager + def _patch_predict_with_generate(self): + origin_mode = self.template.mode + self.template.set_mode('pt') + is_multimodal = self.model.model_meta.is_multimodal + origin_data_collator = self.data_collator + + if is_multimodal: + models = self.template.remove_post_encode_hook() 
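+            # For multimodal templates the post-encode hook is detached here so that raw
+            # samples can flow through `_predict_data_collator` untouched; it is re-registered
+            # in the `finally` block once evaluation with generate has finished.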
+ self.data_collator = self._predict_data_collator + try: + yield + finally: + if is_multimodal: + self.template.register_post_encode_hook(models) + self.data_collator = origin_data_collator + self.template.set_mode(origin_mode) + + def evaluate(self, *args, **kwargs): + context = self._patch_predict_with_generate() if self.args.predict_with_generate else nullcontext() + with context: + res = super().evaluate(*args, **kwargs) + gc_collect() + return res + + def prediction_step( + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + **gen_kwargs, + ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: + if not self.args.predict_with_generate or prediction_loss_only: + return super().prediction_step( + model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys) + from swift.llm import RequestConfig, InferRequest + data_list = inputs['_data'] + labels_list = [InferRequest.remove_response(data['messages']) for data in data_list] + resp_list = self.infer_engine.infer( + data_list, + RequestConfig(max_tokens=self.model.generation_config.max_new_tokens), + use_tqdm=False, + template=self.template) + + response_list = [] + jsonl_cache = [] + device = self.args.device + for data, resp, labels in zip(data_list, resp_list, labels_list): + response = resp.choices[0].message.content + jsonl_cache.append({'response': response, 'labels': labels, **data}) + response_list.append(Serializer.to_tensor(resp.choices[0].message.content).to(device=device)) + self.jsonl_writer.append(jsonl_cache, gather_obj=True) + labels_list = [Serializer.to_tensor(labels).to(device=device) for labels in labels_list] + response_list = pad_sequence(response_list, batch_first=True, padding_value=0) + labels_list = pad_sequence(labels_list, batch_first=True, padding_value=0) + return None, response_list, labels_list + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + loss_kwargs = {} + labels = None + if (self.label_smoother is not None or self.compute_loss_func is not None) and 'labels' in inputs: + labels = inputs.pop('labels') + + loss_scale = inputs.pop('loss_scale', None) + if loss_scale is not None: + loss_kwargs['loss_scale'] = loss_scale + + with self.template.compute_loss_context(self.model, inputs): + outputs = model(**inputs) + # Save past state if it exists + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index] + + if labels is None: + labels = inputs['labels'] + outputs.loss = outputs.loss.to(labels.device) + # fix https://github.com/huggingface/transformers/issues/34263 + if num_items_in_batch is not None: + outputs.loss = outputs.loss * (labels[:, 1:] != -100).sum() / num_items_in_batch + + if isinstance(outputs, dict) and 'loss' not in outputs: + raise ValueError( + 'The model did not return a loss from the inputs, only the following keys: ' + f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}.") + # We don't use .loss here since the model may return tuples instead of ModelOutput. 
+ loss = outputs['loss'] if isinstance(outputs, dict) else outputs[0] + else: + unwrapped_model = self.accelerator.unwrap_model(model) + if is_peft_available() and isinstance(unwrapped_model, PeftModel): + model_name = unwrapped_model.model._get_name() + else: + model_name = unwrapped_model._get_name() + # User-defined compute_loss function + if self.compute_loss_func is not None: + loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch, **loss_kwargs) + elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): + loss = self.label_smoother(outputs, labels, shift_labels=True) + else: + loss = self.label_smoother(outputs, labels) + + if self.template.sequence_parallel_size > 1: + from swift.trainers.sequence_parallel import sequence_parallel + loss = sequence_parallel.reduce_outputs(loss, labels) + + if getattr(self.args, 'average_tokens_across_devices', False) and self.model_accepts_loss_kwargs: + loss *= self.accelerator.num_processes + + if outputs.logits is not None and labels is not None: + # Liger does not have logits + self._compute_acc(outputs, labels) + return (loss, outputs) if return_outputs else loss diff --git a/swift/trainers/utils.py b/swift/trainers/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5540f9f13062a1e974d0c2ed12b71caa2d659d1f --- /dev/null +++ b/swift/trainers/utils.py @@ -0,0 +1,53 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Part of the implementation is borrowed from huggingface/transformers. +import inspect +from types import FunctionType, MethodType +from typing import List, Union + +from peft import PeftModel +from torch.nn import Module + +from swift.utils import get_logger + +logger = get_logger() + + +def can_return_loss(model: Module) -> bool: + """Check if a given model can return loss.""" + if isinstance(model, PeftModel): + signature = inspect.signature(model.model.forward) + else: + signature = inspect.signature(model.forward) + for p in signature.parameters: + if p == 'return_loss' and signature.parameters[p].default is True: + return True + return False + + +def find_labels(model: Module) -> List[str]: + """Find the labels used by a given model.""" + model_name = model.__class__.__name__ + if isinstance(model, PeftModel): + signature = inspect.signature(model.model.forward) + else: + signature = inspect.signature(model.forward) + if 'QuestionAnswering' in model_name: + return [p for p in signature.parameters if 'label' in p or p in ('start_positions', 'end_positions')] + else: + return [p for p in signature.parameters if 'label' in p] + + +def get_function(method_or_function: Union[MethodType, FunctionType]) -> FunctionType: + if isinstance(method_or_function, MethodType): + method_or_function = method_or_function.__func__ + return method_or_function + + +def is_instance_of_ms_model(model: Module) -> bool: + """avoid import modelscope: circular dependency problem""" + for m_cls in model.__class__.__mro__: + cls_name = m_cls.__name__ + cls_module = m_cls.__module__ + if cls_name == 'Model' and cls_module.startswith('modelscope'): + return True + return False diff --git a/swift/tuners/__init__.py b/swift/tuners/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..35eb48aa897aaeb6426fd28a94cbe561927210d8 --- /dev/null +++ b/swift/tuners/__init__.py @@ -0,0 +1,57 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
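+# Lazy-import facade for the tuner implementations: under TYPE_CHECKING the real
+# symbols are imported for static analysis, while at runtime the module is replaced
+# by a `_LazyModule` that resolves entries of `_import_structure` on first attribute
+# access, e.g. `from swift.tuners import Swift, LoRAConfig`.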
+from typing import TYPE_CHECKING + +from swift.utils.import_utils import _LazyModule + +if TYPE_CHECKING: + from .adapter import Adapter, AdapterConfig, AdapterModule + from .base import SwiftModel, Swift + from .lora import LoRA, LoRAConfig + from .mapping import SWIFT_MAPPING, SwiftTuners + from .side import Side, SideConfig, SideModule + from .neftune import NEFTune, NEFTuneConfig + from .longlora.longlora import LongLoRAModelType, LongLoRAConfig, LongLoRA + from .restuning import ResTuning, ResTuningConfig, ResTuningBypassModule + from .reft import Reft, ReftConfig + from .llamapro import LLaMAPro, LLaMAProConfig + from .peft import (AdaLoraConfig, LoftQConfig, LoHaConfig, LoKrConfig, LoraConfig, VeraConfig, BOFTConfig, + OFTConfig, PeftConfig, PeftModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM, + PeftModelForSequenceClassification, PeftModelForTokenClassification, PrefixTuningConfig, + PromptEncoderConfig, PromptLearningConfig, PromptTuningConfig, get_peft_config, get_peft_model, + get_peft_model_state_dict) + from .prompt import Prompt, PromptConfig, PromptModule + from .scetuning.scetuning import SCETuning, SCETuningConfig + from .utils import SwiftConfig, SwiftOutput, swift_to_peft_format +else: + _import_structure = { + 'adapter': ['Adapter', 'AdapterConfig', 'AdapterModule'], + 'base': ['SwiftModel', 'Swift'], + 'lora': ['LoRA', 'LoRAConfig'], + 'longlora.longlora': ['LongLoRAModelType', 'LongLoRAConfig', 'LongLoRA'], + 'mapping': ['SWIFT_MAPPING', 'SwiftTuners'], + 'side': ['Side', 'SideConfig', 'SideModule'], + 'reft': ['Reft', 'ReftConfig'], + 'llamapro': ['LLaMAPro', 'LLaMAProConfig'], + 'neftune': ['NEFTune', 'NEFTuneConfig'], + 'restuning': ['ResTuning', 'ResTuningConfig', 'ResTuningBypassModule'], + 'peft': [ + 'AdaLoraConfig', 'LoftQConfig', 'LoHaConfig', 'LoKrConfig', 'LoraConfig', 'VeraConfig', 'BOFTConfig', + 'OFTConfig', 'PeftConfig', 'PeftModel', 'PeftModelForCausalLM', 'PeftModelForSeq2SeqLM', + 'PeftModelForSequenceClassification', 'PeftModelForTokenClassification', 'PrefixTuningConfig', + 'PromptEncoderConfig', 'PromptLearningConfig', 'PromptTuningConfig', 'get_peft_config', 'get_peft_model', + 'get_peft_model_state_dict' + ], + 'prompt': ['Prompt', 'PromptConfig', 'PromptModule'], + 'scetuning': ['SCETuning', 'SCETuningConfig'], + 'utils': ['SwiftConfig', 'SwiftOutput', 'swift_to_peft_format'], + } + + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/swift/tuners/__pycache__/__init__.cpython-310.pyc b/swift/tuners/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..693db7b964c363549e4be4409a5be5e60476699b Binary files /dev/null and b/swift/tuners/__pycache__/__init__.cpython-310.pyc differ diff --git a/swift/tuners/__pycache__/adapter.cpython-310.pyc b/swift/tuners/__pycache__/adapter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00f374b52f3c1e1eb2a28f2071ac705e69a78929 Binary files /dev/null and b/swift/tuners/__pycache__/adapter.cpython-310.pyc differ diff --git a/swift/tuners/__pycache__/base.cpython-310.pyc b/swift/tuners/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b84097a290c343f39a2be3f1c1e5152a083142a Binary files /dev/null and b/swift/tuners/__pycache__/base.cpython-310.pyc differ diff --git a/swift/tuners/__pycache__/llamapro.cpython-310.pyc 
b/swift/tuners/__pycache__/llamapro.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41511c6f9e7ba665c2d87b43042d43e65ce24331 Binary files /dev/null and b/swift/tuners/__pycache__/llamapro.cpython-310.pyc differ diff --git a/swift/tuners/__pycache__/lora.cpython-310.pyc b/swift/tuners/__pycache__/lora.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a84171d197e6e2b2939e5e82d99cb82d1d3881d Binary files /dev/null and b/swift/tuners/__pycache__/lora.cpython-310.pyc differ diff --git a/swift/tuners/__pycache__/lora_layers.cpython-310.pyc b/swift/tuners/__pycache__/lora_layers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0600d09f11b869249ad2df008d2139c2a3001bc1 Binary files /dev/null and b/swift/tuners/__pycache__/lora_layers.cpython-310.pyc differ diff --git a/swift/tuners/__pycache__/mapping.cpython-310.pyc b/swift/tuners/__pycache__/mapping.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e525518b983fa4eb731f8aaf027ebe45f24cbc8 Binary files /dev/null and b/swift/tuners/__pycache__/mapping.cpython-310.pyc differ diff --git a/swift/tuners/__pycache__/neftune.cpython-310.pyc b/swift/tuners/__pycache__/neftune.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c081c40193eabb6242134b66461a74790579d0c Binary files /dev/null and b/swift/tuners/__pycache__/neftune.cpython-310.pyc differ diff --git a/swift/tuners/__pycache__/part.cpython-310.pyc b/swift/tuners/__pycache__/part.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95b0f18b0197792a6df6536f3195526013dff0b5 Binary files /dev/null and b/swift/tuners/__pycache__/part.cpython-310.pyc differ diff --git a/swift/tuners/__pycache__/peft.cpython-310.pyc b/swift/tuners/__pycache__/peft.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94740e1005b5c49a5230bb83ff2cc126449aaf6a Binary files /dev/null and b/swift/tuners/__pycache__/peft.cpython-310.pyc differ diff --git a/swift/tuners/__pycache__/prompt.cpython-310.pyc b/swift/tuners/__pycache__/prompt.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a0c48eb1728355e91aef02f503b55a4e7c51d53 Binary files /dev/null and b/swift/tuners/__pycache__/prompt.cpython-310.pyc differ diff --git a/swift/tuners/__pycache__/reft.cpython-310.pyc b/swift/tuners/__pycache__/reft.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4167ff40f93176fd006da9134a72f5c125d3f51c Binary files /dev/null and b/swift/tuners/__pycache__/reft.cpython-310.pyc differ diff --git a/swift/tuners/__pycache__/restuning.cpython-310.pyc b/swift/tuners/__pycache__/restuning.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..099e3a469d2d1f9218c31e7d099c1841e7f6a034 Binary files /dev/null and b/swift/tuners/__pycache__/restuning.cpython-310.pyc differ diff --git a/swift/tuners/__pycache__/restuning_components.cpython-310.pyc b/swift/tuners/__pycache__/restuning_components.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6cecb67629246f39dd0443f9b84f20891bff238 Binary files /dev/null and b/swift/tuners/__pycache__/restuning_components.cpython-310.pyc differ diff --git a/swift/tuners/__pycache__/side.cpython-310.pyc b/swift/tuners/__pycache__/side.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..01bbb3202a3c31d85c27b67fa37703698402a1c9 Binary files /dev/null and b/swift/tuners/__pycache__/side.cpython-310.pyc differ diff --git a/swift/tuners/__pycache__/utils.cpython-310.pyc b/swift/tuners/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bcb409f3892ea7f10fbdebf336ecd749f593f023 Binary files /dev/null and b/swift/tuners/__pycache__/utils.cpython-310.pyc differ diff --git a/swift/tuners/adapter.py b/swift/tuners/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..290040b551b5e969eeb7b59bcc7dfd63536b57e3 --- /dev/null +++ b/swift/tuners/adapter.py @@ -0,0 +1,189 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import inspect +import re +import types +from dataclasses import dataclass, field +from typing import List, Union + +import torch +from torch import nn +from transformers.activations import ACT2CLS + +from swift.utils.torch_utils import find_sub_module, get_logger +from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput + +logger = get_logger() + + +@dataclass +class AdapterConfig(SwiftConfig): + """ + The configuration class for the adapter module. + + Adapters project input tokens by an MLP layer. + 'Parameter-Efficient Transfer Learning for NLP' by Houlsby et al.(2019) + See http://arxiv.org/abs/1902.00751 + + Args: + dim(`int`): The dimension of the hidden states + target_modules(`Union[str, List[str]]`): The feedforward module to be replaced. + in regex format if this argument is str, else will match with `end with` if List[str]. + hidden_pos(`Union[str, int]`): The position of the hidden state to be passed into the adapter, + can be int (args) or str (kwargs) + method_name(`str`): The method to be replaced, default is `forward` + adapter_length: The length of the adapter length (intermediate length) + act_layer: The activation layer of the adapter + """ + + dim: int = field(default=None, metadata={'help': 'The dimension of the hidden states'}) + + target_modules: Union[str, List[str]] = field( + default=None, + metadata={ + 'help': + 'The feedforward module to be replaced. in regex format if this argument is str, ' + 'else will match with `end with` if List[str].' 
+ }) + + hidden_pos: Union[str, int] = field( + default=None, + metadata={ + 'help': 'The position of the hidden state to be passed into the adapter, can be int (args) or str (kwargs)' + }) + + method_name: str = field(default='forward', metadata={'help': 'The method to be replaced, default is `forward`'}) + + adapter_length: int = field( + default=128, metadata={'help': 'The length of the adapter length (intermediate length)'}) + + act_layer: str = field(default='gelu', metadata={'help': 'The activation layer of the adapter'}) + + def __post_init__(self): + from .mapping import SwiftTuners + self.swift_type = SwiftTuners.ADAPTER + + +class Adapter(SwiftAdapter): + + @staticmethod + def prepare_model(model: nn.Module, config: AdapterConfig, adapter_name: str) -> SwiftOutput: + """Prepare a model with `AdapterConfig`""" + module_keys = [key for key, _ in model.named_modules()] + + for module_key in module_keys: + if isinstance(config.target_modules, str): + target_module_found = re.fullmatch(config.target_modules, module_key) + else: + target_module_found = any(module_key.endswith(target_key) for target_key in config.target_modules) + + if target_module_found: # noqa + module = model.get_submodule(module_key) + + def _forward(self, *args, **kwargs): + args = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs) + if isinstance(args, (tuple, list, dict)): + if isinstance(config.hidden_pos, int): + _type = type(args) + args = list(args) + args[config.hidden_pos] = getattr(self, f'adapter_{adapter_name}')(args[config.hidden_pos]) + args = _type(args) + else: + args[config.hidden_pos] = getattr(self, f'adapter_{adapter_name}')(args[config.hidden_pos]) + elif isinstance(args, torch.Tensor): + args = getattr(self, f'adapter_{adapter_name}')(args) + return args + + def _feed_forward_chunk(self, attention_output): + return _forward(self, attention_output) + + # TODO The `config.method_name` method should not be replaced twice. + + setattr(module, f'forward_origin_{adapter_name}', getattr(module, config.method_name)) + num_args_in_forward_chunk_fn = len( + inspect.signature(getattr(module, f'forward_origin_{adapter_name}')).parameters) + if config.method_name == 'feed_forward_chunk' and num_args_in_forward_chunk_fn == 1: + setattr(module, config.method_name, types.MethodType(_feed_forward_chunk, module)) + else: + setattr(module, config.method_name, types.MethodType(_forward, module)) + adapter_module = AdapterModule(config.dim, adapter_name, module_key, config.adapter_length, + ACT2CLS[config.act_layer]) + setattr(module, f'adapter_{adapter_name}', adapter_module) + logger.info(f'Adapter modules(module_key): {module_key}.adapter_{adapter_name}') + + def state_dict_callback(state_dict, adapter_name: str, **kwargs): + return {key: value for key, value in state_dict.items() if f'adapter_{adapter_name}' in key} + + def mark_trainable_callback(model): + return + + return SwiftOutput( + config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback) + + @staticmethod + def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None): + modules = find_sub_module(module, f'adapter_{adapter_name}') + for _module in modules: + _module: ActivationMixin + _module: nn.Module + _module.set_activation(adapter_name, activate) + SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload) + + +class AdapterModule(nn.Module, ActivationMixin): + """The implementation of adapter tuning method. 
+ + Adapters project input tokens by an MLP layer. + 'Parameter-Efficient Transfer Learning for NLP' by Houlsby et al.(2019) + See http://arxiv.org/abs/1902.00751 + + Args: + dim: An integer indicating the embedding dimension. + adapter_length: An integer indicating the length of adapter tuning. + """ + + def __init__( + self, + dim, + adapter_name, + module_key, + adapter_length=None, + act_layer=nn.GELU, + ): + super(AdapterModule, self).__init__() + super(nn.Module, self).__init__(module_key) + self.dim = dim + self.adapter_name = adapter_name + self.adapter_length = adapter_length + self.linear1 = nn.Linear(dim, adapter_length) + self.act = act_layer() + self.linear2 = nn.Linear(adapter_length, dim) + self.init_weights() + self._prepared = False + self.mark_all_sub_modules_as_plugin() + + def init_weights(self): + + def _init_weights(m): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.normal_(m.bias, std=1e-6) + + self.apply(_init_weights) + + def forward(self, x, identity=None): + if not self.is_activated(self.adapter_name): + return x + if not self._prepared: + self.linear1.to(x.device) + self.act.to(x.device) + self.linear2.to(x.device) + self._prepared = True + + x_dtype = x.dtype + x = x.to(self.linear1.weight.dtype) + out = self.linear2(self.act(self.linear1(x))) + if identity is None: + identity = x + identity = identity.to(out.dtype) + out = identity + out + return out.to(x_dtype) diff --git a/swift/tuners/base.py b/swift/tuners/base.py new file mode 100644 index 0000000000000000000000000000000000000000..fafc0883abce55d975055352ade4d9f5b3cbdd58 --- /dev/null +++ b/swift/tuners/base.py @@ -0,0 +1,926 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2023-present the HuggingFace Inc. team. +import os +import re +import shutil +import tempfile +from contextlib import contextmanager +from copy import copy +from functools import partial +from inspect import Parameter, Signature, signature +from types import MethodType +from typing import Dict, List, Literal, Optional, Union + +import json +import torch +from modelscope import snapshot_download +from peft.utils import CONFIG_NAME +from peft.utils.other import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME +from torch import nn +from transformers import Trainer + +from swift.utils.constants import DEFAULT_ADAPTER, SWIFT_TYPE_KEY +from swift.utils.logger import get_logger +from ..utils.torch_utils import get_device_count +from .mapping import SwiftTuners +from .peft import PeftConfig, PeftModel, get_peft_model +from .utils import SwiftConfig, SwiftOutput + +logger = get_logger() + + +class SwiftModel(nn.Module): + """The Swift wrapper model. + + Args: + model (`Union[nn.Module, 'SwiftModel']`) A module to be tuned by Swift. + config (`Union[SwiftConfig, Dict[str, SwiftConfig]]`) A config or a dict of {adapter_name: SwiftConfig}. + If it's a config class, the adapter_name will be `default` + extra_state_keys (`List[str]`, `optional`) A list of regex to match the extra state keys to be saved. + inference_mode (bool, `optional`): Load model at inference mode, default False. 
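+
+    Example (a minimal usage sketch; `base_model` and the `LoRAConfig` arguments are
+    illustrative placeholders, not required values)::
+
+        from swift.tuners import LoRAConfig, SwiftModel
+
+        config = LoRAConfig(r=8, target_modules=['q_proj', 'v_proj'])
+        model = SwiftModel(base_model, config)  # wraps and patches `base_model`
+        print(model.get_trainable_parameters())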
+ """ + + EXTRA_STATE_DIR = 'extra_states' + + def __init__(self, + model: Union[nn.Module, 'SwiftModel'], + config: Union[SwiftConfig, Dict[str, SwiftConfig]], + extra_state_keys: List[str] = None, + inference_mode: bool = False, + **kwargs): + super().__init__() + self.adapters = {} + self.active_adapters = set() + if isinstance(model, SwiftModel): + self.adapters = model.adapters + extra_state_keys = extra_state_keys or [] + extra_state_keys.extend(model.extra_state_keys) + self.active_adapters = model.active_adapters + model = model.base_model + + self.base_model = model + new_adapters = [] + if isinstance(config, SwiftConfig): + if DEFAULT_ADAPTER not in self.adapters: + all_parts = self._deactivate_all_parts() + self.adapters[DEFAULT_ADAPTER] = self._prepare_model(model, config, DEFAULT_ADAPTER) + for part in all_parts: + self.activate_adapter(part) + new_adapters.append(DEFAULT_ADAPTER) + if self.adapters[DEFAULT_ADAPTER].model is not None: + self.base_model = self.adapters[DEFAULT_ADAPTER].model + else: + logger.warn(f'Adapter {DEFAULT_ADAPTER} has been patched, skip.') + elif isinstance(config, dict): + assert (all(isinstance(c, SwiftConfig) for c in config.values())) + for adapter_name, _config in config.items(): + if adapter_name not in self.adapters: + all_parts = self._deactivate_all_parts() + self.adapters[adapter_name] = self._prepare_model(model, _config, adapter_name) + for part in all_parts: + self.activate_adapter(part) + new_adapters.append(adapter_name) + if self.adapters[adapter_name].model is not None: + self.base_model = self.adapters[adapter_name].model + else: + logger.warn(f'Adapter {adapter_name} has been patched, skip.') + + self.extra_state_keys = extra_state_keys or [] + self.has_additional_modules = any([c.config.has_additional_modules for c in self.adapters.values()]) + + def forward(self, *args, **kwargs): + return self.base_model(*args, **kwargs) + + _parameters = [Parameter('self', Parameter.POSITIONAL_ONLY)] + _parameters += list(signature(self.base_model.forward).parameters.values()) + forward.__signature__ = Signature(_parameters) + self.forward = MethodType(forward, self) + for adapter_name in new_adapters: + self.activate_adapter(adapter_name) + + if inference_mode: + self.eval() + else: + for key, output in self.adapters.items(): + if key in new_adapters: + output.mark_trainable_callback(model) + if self.extra_state_keys: + for n, p in model.named_parameters(): + if any(re.fullmatch(extra_key, n) for extra_key in self.extra_state_keys): + p.requires_grad = True + + @property + def model(self): + return self.base_model + + def _deactivate_all_parts(self): + deactivated = [] + for adapter in self.active_adapters: + output = self.adapters[adapter] + if output.config.swift_type == SwiftTuners.PART: + deactivated.append(adapter) + self.deactivate_adapter(adapter) + return deactivated + + def load_state_dict(self, state_dict, strict=True, adapter_name: str = None): + if adapter_name is not None: + output: SwiftOutput = self.adapters[adapter_name] + if getattr(output.config, 'modules_to_save', None): + for key, value in copy(state_dict).items(): + for module_name in output.config.modules_to_save: + if module_name in key: + state_dict.pop(key) + key = key.replace(module_name, f'{module_name}.modules_to_save.{adapter_name}') + break + state_dict[key] = value + + for key, value in copy(state_dict).items(): + if key.startswith('base_model.model.'): + state_dict.pop(key, None) + key = key[len('base_model.model.'):] + if f'lora_A.{adapter_name}.' 
not in key and 'lora_A' in key: + state_dict.pop(key, None) + key = key.replace('lora_A.', f'lora_A.{adapter_name}.') + if f'lora_B.{adapter_name}.' not in key and 'lora_B' in key: + state_dict.pop(key, None) + key = key.replace('lora_B.', f'lora_B.{adapter_name}.') + if f'lora_embedding_A.{adapter_name}.' not in key and 'lora_embedding_A' in key: + state_dict.pop(key, None) + key = key.replace('lora_embedding_A.', f'lora_embedding_A.{adapter_name}.') + if f'lora_embedding_B.{adapter_name}.' not in key and 'lora_embedding_B' in key: + state_dict.pop(key, None) + key = key.replace('lora_embedding_B.', f'lora_embedding_B.{adapter_name}.') + state_dict[key] = value + + if output.load_state_dict_callback: + state_dict = output.load_state_dict_callback(self.base_model, adapter_name, state_dict) + + incompatible_keys = self.base_model.load_state_dict(state_dict, False) + if incompatible_keys and len(incompatible_keys[1]) > 0: + logger.error(f'Load state dict with unexpected keys: {incompatible_keys[1]}') + + def state_dict(self, + *args, + destination=None, + prefix='', + keep_vars=False, + adapter_name: str = None, + peft_format: bool = False, + **kwargs): + """ + Args: + destination (`dict`, `optional`): If provided, the state of module will + be updated into the dict and the same object is returned. + Otherwise, an ``OrderedDict`` will be created and returned. + Default: ``None``. + prefix (`str`, `optional`): a prefix added to parameter and buffer + names to compose the keys in state_dict. Default: ``''``. + keep_vars (`bool`, `optional`): by default the :class:`~torch.Tensor` s + returned in the state dict are detached from autograd. If it's + set to ``True``, detaching will not be performed. + Default: ``False``. + adapter_name (`str`, `optional`): The name of the adapter's parameters to be saved, + `None` input will save all adapters. + peft_format (`bool`, `optional`): Save with peft format (extra `base_model.model.` prefix) + **kwargs: + save_adapter(`bool`): Save adapters or not, default True + save_extra_states(`bool`): Save extra states or not, default True + Returns: + The state dict to be saved. + """ + state_dict = kwargs.get('state_dict') + if state_dict is None: + state_dict = self.base_model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) + state_dict = { + key[len('base_model.'):] if key.startswith('base_model.') else key: value + for key, value in state_dict.items() + } + if not self.has_additional_modules: + return state_dict + + state_dicts = {} + if kwargs.get('save_adapter', True): + for name, output in self.adapters.items(): + if (adapter_name == name or adapter_name is None) and output.config.has_additional_modules: # noqa + state_dicts.update(output.state_dict_callback(state_dict, name)) + modules_to_save_names = [ + sub_name for sub_name, _ in self.base_model.named_parameters() + if f'modules_to_save.{name}' in sub_name + ] + for module_name in modules_to_save_names: + if f'modules_to_save.{name}' in module_name: + state_dicts[module_name.replace(f'modules_to_save.{name}.', '')] = state_dict[module_name] + if kwargs.get('save_extra_states', True): + state_dicts.update({ + k: v + for k, v in state_dict.items() if any( + re.fullmatch(extra_key, k) for extra_key in self.extra_state_keys) + }) + if peft_format: + new_state_dict = {} + for key, value in state_dicts.items(): + if not key.startswith('base_model.model.'): + key = 'base_model.model.' 
+ key + key = key.replace(f'lora_A.{adapter_name}.', 'lora_A.') + key = key.replace(f'lora_B.{adapter_name}.', 'lora_B.') + key = key.replace(f'lora_embedding_A.{adapter_name}.', 'lora_embedding_A.') + key = key.replace(f'lora_embedding_B.{adapter_name}.', 'lora_embedding_B.') + new_state_dict[key] = value + state_dicts = new_state_dict + return state_dicts + + def __getattr__(self, key: str): + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(key) + except AttributeError: + if 'base_model' in dir(self): + return getattr(self.base_model, key) + raise + + @staticmethod + def load_state_file(path, device: Optional[str] = None): + """Load a state dict file by the input path. + + Args: + path: The local dir to load the state file. + + Returns: + The state dict. + """ + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + if os.path.exists(os.path.join(path, SAFETENSORS_WEIGHTS_NAME)): + filename = os.path.join(path, SAFETENSORS_WEIGHTS_NAME) + from safetensors.torch import load_file as safe_load_file + return safe_load_file(filename, device=device) + elif os.path.exists(os.path.join(path, WEIGHTS_NAME)): + filename = os.path.join(path, WEIGHTS_NAME) + return torch.load(filename, map_location=device) + return None + + def create_optimizer_param_groups(self, **defaults): + all_param_names = set() + param_groups = [] + for output in self.adapters.values(): + if output.optimizer_group_callback: + param_names, param_group = output.optimizer_group_callback(self.model, **defaults) + if param_names and all_param_names & param_names: + raise ValueError('Cannot set one parameter to different param groups') + if param_names and param_group: + all_param_names.update(param_names) + param_groups.extend(param_group) + + decay_parameters = Trainer.get_decay_parameter_names(None, self.model) + param_groups.extend([ + { + 'params': [ + p for n, p in self.model.named_parameters() + if (n in decay_parameters and n not in all_param_names and p.requires_grad) + ], + 'weight_decay': + defaults['weight_decay'], + }, + { + 'params': [ + p for n, p in self.model.named_parameters() + if (n not in decay_parameters and n not in all_param_names and p.requires_grad) + ], + 'weight_decay': + 0.0, + }, + ]) + + return param_groups + + @classmethod + def from_pretrained(cls, + model: Union[nn.Module, 'SwiftModel'], + model_id: str = None, + adapter_name: Union[str, List[str], Dict[str, str]] = None, + inference_mode: bool = True, + revision: str = None, + **kwargs): + """Load a set of tuners and corresponding weights by a model_id. + + Args: + model (`Union[torch.nn.Module, 'SwiftModel']`): The model to be tuned, + if the model is already a `SwiftModel` it will be un-wrapped and re-wrapped.. + model_id (`str`): The model_id or a local model dir of tuners to use to tune the model. + adapter_name (`Union[str, List[str], Dict[str, str]]`): The adapter_names saved in the model repo to load. + Default `None`, means load all tuners saved in the model_id + inference_mode (`bool`): Use in the inference mode or not. + revision (`str`): The model revision to use. + **kwargs: + extra_state_keys (`List[str]`, `optional`) A list of regex to match the extra state keys to be saved. + Other parameters will be passed to the device_map. + Returns: + The `SwiftModel` instance. 
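+
+        Example (a sketch; the checkpoint directory and adapter names are placeholders)::
+
+            # load every adapter saved under the checkpoint directory
+            model = SwiftModel.from_pretrained(base_model, 'output/checkpoint-100')
+            # or load a single saved adapter and give it a new name
+            model = SwiftModel.from_pretrained(
+                base_model, 'output/checkpoint-100', adapter_name={'default': 'my_lora'})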
+ """ + adapters = {} + model_dir = model_id + if not os.path.exists(model_dir): + model_dir = snapshot_download(model_dir, revision=revision) + if os.path.isfile(model_dir): + raise ValueError(f'Please pass in a local dir or a model id, not a local file: {model_dir}') + extra_state_keys = kwargs.pop('extra_state_keys', None) + if extra_state_keys is None and os.path.isfile(os.path.join(model_dir, cls.EXTRA_STATE_DIR, CONFIG_NAME)): + with open(os.path.join(model_dir, cls.EXTRA_STATE_DIR, CONFIG_NAME), 'r', encoding='utf-8') as file: + _json = json.load(file) + extra_state_keys = _json.get('extra_state_keys') + if adapter_name is None: + adapter_name = [ + sub_dir for sub_dir in os.listdir(model_dir) + if os.path.isfile(os.path.join(model_dir, sub_dir, CONFIG_NAME)) and sub_dir != cls.EXTRA_STATE_DIR + ] + for _name in adapter_name if isinstance(adapter_name, + list) else [adapter_name] \ + if isinstance(adapter_name, str) else adapter_name.keys(): + sub_folder = os.path.join(model_dir, _name) + config_file = os.path.join(sub_folder, CONFIG_NAME) + + if not os.path.isfile(config_file): + logger.warning(f'{_name} is not a valid tuner') + continue + + with open(config_file, 'r', encoding='utf-8') as file: + json_object = json.load(file) + + if SWIFT_TYPE_KEY not in json_object: + raise ValueError('Mixed using with peft is not allowed now.') + else: + key = _name if not isinstance(adapter_name, dict) else adapter_name[_name] + adapters[key] = SwiftConfig.from_pretrained(sub_folder) + + self = SwiftModel(model, adapters, extra_state_keys, inference_mode, **kwargs) + for _name in adapter_name if isinstance(adapter_name, + list) else [adapter_name] \ + if isinstance(adapter_name, str) else adapter_name.keys(): + _adapter = _name if not isinstance(adapter_name, dict) else adapter_name[_name] + output: SwiftOutput = self.adapters[_adapter] + sub_folder = os.path.join(model_dir, _name) + if output.load_callback: + output.load_callback(self, sub_folder, _adapter) + continue + state_dict = cls.load_state_file(sub_folder) + if state_dict is not None: + if isinstance(adapter_name, dict): + # TODO this logic is fragile! replace `_name` may cause other parts replaced + state_dict = {key.replace(_name, adapter_name[_name]): value for key, value in state_dict.items()} + self.load_state_dict(state_dict, adapter_name=_adapter) + state_dict = cls.load_state_file(os.path.join(model_dir, self.EXTRA_STATE_DIR)) + if state_dict is not None: + self.load_state_dict(state_dict) + return self + + @classmethod + def _prepare_model( + cls, + model: nn.Module, + config: SwiftConfig, + adapter_name: str, + ): + assert (hasattr(config, SWIFT_TYPE_KEY)) + from .mapping import SWIFT_MAPPING + + adapter_cls = SWIFT_MAPPING[config.swift_type][1] + if adapter_cls.has_additional_modules() and not getattr(model, 'model_frozen', False): + for _, p in model.named_parameters(): + p.requires_grad = False + model.model_frozen = True + config.has_additional_modules = adapter_cls.has_additional_modules() + return adapter_cls.prepare_model(model, config, adapter_name) + + def create_or_update_model_card(self, output_dir: str): + """ + Updates or create the model card. 
+ """ + if not os.path.exists(os.path.join(output_dir, 'README.md')): + lines = [] + else: + with open(os.path.join(output_dir, 'README.md'), 'r', encoding='utf-8') as f: + lines = f.readlines() + + quantization_config = None + if hasattr(self.base_model, 'config') and hasattr(self.base_model.config, 'quantization_config'): + if hasattr(self.base_model.config.quantization_config, 'to_dict'): + quantization_config = self.base_model.config.quantization_config.to_dict() + training_config_text = '' + # Adds quantization information if it was used + if quantization_config is not None: + training_config_text += '\nThe following `bitsandbytes` quantization config was used during training:\n' + training_config_text += '\n'.join([f'- {name}: {value}' for name, value in quantization_config.items()]) + training_config_text += '\n' + + training_procedure_heading = '## Training procedure\n' + if training_procedure_heading in lines: + lines.insert(lines.index(training_procedure_heading) + 2, training_config_text) + else: + lines.append(f'{training_procedure_heading}\n{training_config_text}') + + framework_block_heading = '### Framework versions\n' + from swift.version import __version__ + if framework_block_heading in lines: + lines.insert(lines.index(framework_block_heading) + 2, f'- SWIFT {__version__}\n') + else: + lines.append(f'{framework_block_heading}\n\n- SWIFT {__version__}\n') + + base_model_heading = '### Base model information\n' + lines.append(f'{base_model_heading}\n\n- BaseModel Class {self.base_model.__class__.__name__}\n') + + # write the lines back to README.md + with open(os.path.join(output_dir, 'README.md'), 'w', encoding='utf-8') as f: + f.writelines(lines) + + def add_weighted_adapter( + self, + adapters, + weights, + adapter_name, + combination_type='svd', + svd_rank=None, + svd_clamp=None, + svd_full_matrices=True, + svd_driver=None, + density=None, + majority_sign_method: Literal['total', 'frequency'] = 'total', + ): + """ + This method adds a new adapter by merging the given adapters with the given weights. + + When using the `cat` combination_type you should be aware that rank of the resulting adapter will be equal to + the sum of all adapters ranks. So it's possible that the mixed adapter may become too big and result in OOM + errors. + + Args: + adapters (`list`): + List of adapter names to be merged. + weights (`list`): + List of weights for each adapter. + adapter_name (`str`): + Name of the new adapter. + combination_type (`str`): + The merging type can be one of [`svd`, `linear`, `cat`, `ties`, `ties_svd`, `dare_ties`, `dare_linear`, + `dare_ties_svd`, `dare_linear_svd`, `magnitude_prune`, `magnitude_prune_svd`]. When using the `cat` + combination_type, the rank of the resulting adapter is equal to the sum of all adapters ranks (the + mixed adapter may be too big and result in OOM errors). + svd_rank (`int`, *optional*): + Rank of output adapter for svd. If None provided, will use max rank of merging adapters. + svd_clamp (`float`, *optional*): + A quantile threshold for clamping SVD decomposition output. If None is provided, do not perform + clamping. Defaults to None. + svd_full_matrices (`bool`, *optional*): + Controls whether to compute the full or reduced SVD, and consequently, the shape of the returned + tensors U and Vh. Defaults to True. + svd_driver (`str`, *optional*): + Name of the cuSOLVER method to be used. This keyword argument only works when merging on CUDA. Can be + one of [None, `gesvd`, `gesvdj`, `gesvda`]. 
For more info please refer to `torch.linalg.svd` + documentation. Defaults to None. + density (`float`, *optional*): + Value between 0 and 1. 0 means all values are pruned and 1 means no values are pruned. Should be used + with [`ties`, `ties_svd`, `dare_ties`, `dare_linear`, `dare_ties_svd`, `dare_linear_svd`, + `magnintude_prune`, `magnitude_prune_svd`] + majority_sign_method (`str`): + The method, should be one of ["total", "frequency"], to use to get the magnitude of the sign values. + Should be used with [`ties`, `ties_svd`, `dare_ties`, `dare_ties_svd`] + """ + from swift.tuners.lora import LoraModel + lora_model = LoraModel(self.model, None, '') + lora_model.peft_config = {key: value.config for key, value in self.adapters.items()} + from peft.tuners.lora import LoraLayer + lora_model.targeted_module_names = [ + key for key, value in self.model.named_modules() if isinstance(value, LoraLayer) + ] + lora_model.active_adapter = self.active_adapters + lora_model.add_weighted_adapter( + adapters=adapters, + weights=weights, + adapter_name=adapter_name, + combination_type=combination_type, + svd_rank=svd_rank, + svd_clamp=svd_clamp, + svd_full_matrices=svd_full_matrices, + svd_driver=svd_driver, + density=density, + majority_sign_method=majority_sign_method, + ) + + def state_dict_callback(state_dict, adapter_name, cfg): + from swift.tuners.lora_layers import lora_state_dict + return lora_state_dict(state_dict, adapter_name, cfg.bias) + + def mark_trainable_callback(model, cfg): + from swift.tuners.lora_layers import mark_lora_as_trainable + mark_lora_as_trainable(model, adapter_name, cfg.bias) + + cfg = lora_model.peft_config[adapter_name] + cfg.has_additional_modules = True + self.adapters[adapter_name] = SwiftOutput( + config=cfg, + state_dict_callback=partial(state_dict_callback, cfg=cfg), + mark_trainable_callback=partial(mark_trainable_callback, cfg=cfg), + optimizer_group_callback=None, + ) + + self.set_active_adapters(adapter_name) + + def save_pretrained(self, + save_directory: str, + safe_serialization: bool = False, + adapter_name: Union[str, List[str]] = None, + **kwargs): + """Save the adapters to a local directory. + + Args: + save_directory (`str`): The directory to use. + safe_serialization (`bool`): Use safe tensors to save the weights, default False. + adapter_name(`Union[str, List[str]]`): The adapters to be saved, default is `None` to save all. 
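+
+        Example (a sketch; the output path is a placeholder)::
+
+            model.save_pretrained('output/my_adapter', safe_serialization=True)
+            # each adapter is written to its own sub-directory together with its config;
+            # pass adapter_name='default' to save only that adapter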
+ """ + peft_format = kwargs.pop('peft_format', False) + if os.path.isfile(save_directory): + raise ValueError(f'Provided path ({save_directory}) should be a directory, not a file') + os.makedirs(save_directory, exist_ok=True) + if not self.has_additional_modules: + if hasattr(self.base_model, 'save_pretrained'): + self.base_model.save_pretrained(save_directory, safe_serialization=safe_serialization) + else: + self._save_state_dict(self.base_model.state_dict(), save_directory, safe_serialization) + self.create_or_update_model_card(save_directory) + else: + self.create_or_update_model_card(save_directory) + + adapter_names = adapter_name if isinstance(adapter_name, list) or adapter_name is None else [adapter_name] + + state_dict_kwargs = {} + state_dict = kwargs.get('state_dict') + if state_dict is not None: + state_dict_kwargs['state_dict'] = kwargs['state_dict'] + for adapter_name, output in self.adapters.items(): + if adapter_names is not None and adapter_name not in adapter_names: + continue + + save_to_peft = peft_format and output.config.swift_type == SwiftTuners.LORA + save_to_peft = save_to_peft and output.config.can_be_saved_to_peft() + if peft_format and not save_to_peft: + logger.error('You are using additional lora parameters, which is not compatible with peft,' + 'which is unable to save to peft format.') + output_dir = os.path.join(save_directory, + adapter_name) if adapter_name != 'default' or not save_to_peft else save_directory + + if save_to_peft: + config = output.config.to_peft_config() + config.save_pretrained(output_dir) + else: + output.config.save_pretrained(output_dir) + + if output.save_callback: + output.save_callback(self, output_dir, adapter_name) + continue + + # save only the trainable weights + output_state_dict = self.state_dict( + adapter_name=adapter_name, save_extra_states=False, peft_format=save_to_peft, **state_dict_kwargs) + os.makedirs(output_dir, exist_ok=True) + if output_state_dict and output.config.has_additional_modules: + self._save_state_dict(output_state_dict, output_dir, safe_serialization) + + output_state_dict = self.state_dict(save_extra_states=True, save_adapter=False, **state_dict_kwargs) + if len(output_state_dict) > 0: + if self.has_additional_modules: + os.makedirs(os.path.join(save_directory, self.EXTRA_STATE_DIR), exist_ok=True) + self._save_state_dict(output_state_dict, os.path.join(save_directory, self.EXTRA_STATE_DIR), + safe_serialization) + with open( + os.path.join(save_directory, self.EXTRA_STATE_DIR, CONFIG_NAME), 'w', encoding='utf-8') as file: + json.dump({'extra_state_keys': self.extra_state_keys}, file) + else: + logger.error('Full parameter training, save_extra_states will be ignored') + + if not os.path.exists(os.path.join(save_directory, 'configuration.json')): + with open(os.path.join(save_directory, 'configuration.json'), 'w', encoding='utf-8') as f: + f.write('{}') + + @staticmethod + def _save_state_dict(output_state_dict, save_directory, safe_serialization): + if safe_serialization: + from safetensors.torch import save_file as safe_save_file + safe_save_file( + output_state_dict, os.path.join(save_directory, SAFETENSORS_WEIGHTS_NAME), metadata={'format': 'pt'}) + else: + torch.save(output_state_dict, os.path.join(save_directory, WEIGHTS_NAME)) + + @contextmanager + def disable_adapter(self): + try: + self.set_active_adapters(adapter_names=[]) + yield + finally: + self.set_active_adapters(adapter_names=self.adapters.keys()) + + def set_active_adapters(self, adapter_names: Union[List[str], str], offload: str = 
None): + """Set activated adapters + + Args: + adapter_names(`Union[List[str], str]`): The adapters needed to be activated + offload(`str`): Whether to offload the deactivated ones to `cpu` or `meta` device + """ + if not adapter_names: + adapter_names = [] + + if isinstance(adapter_names, str): + adapter_names = [adapter_names] + + adapter_names = set(adapter_names) + for adapter_name in (adapter_names & set(self.adapters.keys())): + self.activate_adapter(adapter_name) + + for adapter_name in (set(self.adapters.keys()) - adapter_names): + self.deactivate_adapter(adapter_name, offload) + + self.active_adapters = (adapter_names & set(self.adapters.keys())) + + def activate_adapter(self, adapter_name: str): + """Activate one adapter + + Args: + adapter_name(`str`): The adapter needed to be activated + """ + if adapter_name not in self.adapters: + logger.warning(f'{adapter_name} not in adapters: {self.adapters.keys()}') + return + + from .mapping import SWIFT_MAPPING + SWIFT_MAPPING[self.adapters[adapter_name].config.swift_type][1]\ + .activate_adapter(self.base_model, adapter_name, True) + self.active_adapters = self.active_adapters | {adapter_name} + + def deactivate_adapter(self, adapter_name: str, offload: str = None): + """Deactivate one adapter + + Args: + adapter_name(`str`): The adapter needed to be activated + offload(`str`): Whether to offload to `cpu` or `meta` device + """ + if adapter_name not in self.adapters: + logger.warning(f'{adapter_name} not in adapters: {self.adapters.keys()}') + return + + from .mapping import SWIFT_MAPPING + SWIFT_MAPPING[self.adapters[adapter_name].config.swift_type][1]\ + .activate_adapter(self.base_model, adapter_name, False, offload=offload) + self.active_adapters = self.active_adapters - {adapter_name} + + def get_trainable_parameters(self): + """ + Get the content of trainable parameters in the model. + """ + trainable_params = 0 + all_param = 0 + for _, param in self.base_model.named_parameters(): + num_params = param.numel() + # if using DS Zero 3 and the weights are initialized empty + if num_params == 0 and hasattr(param, 'ds_numel'): + num_params = param.ds_numel + + all_param += num_params + if param.requires_grad: + trainable_params += num_params + return f'trainable params: {trainable_params:,d} || all params: {all_param:,d} ' \ + f'|| trainable%: {100 * trainable_params / all_param:.4f}' \ + '|| cuda memory: ' \ + f'{sum([torch.cuda.memory_allocated(i) for i in range(get_device_count())])/1024/1024/1024:.2f}' \ + 'GiB.' + + +class Swift: + """The Wrapper to use both Peft and Swift tuners.""" + + @staticmethod + def prepare_model(model: Union[nn.Module, SwiftModel], config: Union[SwiftConfig, PeftConfig, + Dict[str, SwiftConfig]], **kwargs): + """Prepare a model by the input config. + + Args: + model(`Union[nn.Module, 'SwiftModel']`): The model to be tuned. + config(`Union[SwiftConfig, PeftConfig, Dict[str, SwiftConfig]]`): The config or config dict, can be either + SwiftConfigs or PeftConfigs + **kwargs: + Extra kwargs needed by SwiftModel or PeftModel. + Returns: + The model wrapped by SwiftModel or PeftModel. + """ + + if isinstance(config, (SwiftConfig, dict)): + return SwiftModel(model, config, **kwargs) + else: + return get_peft_model(model, config, **kwargs) + + @staticmethod + def merge_and_unload(model: Union[PeftModel, SwiftModel], **kwargs): + """Merge tuners into the base model and unload them. 
+ + Args: + model(`Union[PeftModel, SwiftModel]`): The model instance with tuners + kwargs: + adapter_name(`Union[str, List[str]]`): The adapter_name to unload, only supported in swift tuners. + + """ + from peft import PeftModel as _PeftModel + if isinstance(model, _PeftModel): + model.merge_and_unload() + elif isinstance(model, SwiftModel): + from swift import LoRAConfig + from swift.tuners import LoRA + adapter_name = kwargs.get('adapter_name', None) + if isinstance(adapter_name, str): + adapter_name = [adapter_name] + for adapter, output in model.adapters.items(): + if isinstance(output.config, LoRAConfig) and (adapter_name is None or adapter in adapter_name): + LoRA.unpatch_lora(model, output.config, adapter) + + @staticmethod + @contextmanager + def grpo_context(model: Union[SwiftModel, torch.nn.Module], processor): + # Save the model and temporarily modify model.model_dir. + if not isinstance(model, SwiftModel): + yield + return + else: + assert len(model.adapters) == 1 + adapter = list(model.adapters.values())[0] + if adapter.config.swift_type == SwiftTuners.LLAMAPRO: + from modelscope.hub.utils.utils import get_cache_dir + temp_dir = tempfile.mkdtemp(dir=get_cache_dir()) + model_dir = model.model_dir + from transformers.integrations import is_deepspeed_zero3_enabled + if is_deepspeed_zero3_enabled(): + raise ValueError('DeepSpeed ZeRO3 not supported for LLaMAPro&GRPO currently.') + model.base_model.save_pretrained(temp_dir) + processor.save_pretrained(temp_dir) + model.model_dir = temp_dir + yield + if adapter.config.swift_type == SwiftTuners.LLAMAPRO: + model.model_dir = model_dir + shutil.rmtree(temp_dir) + + @staticmethod + def merge(model: Union[PeftModel, SwiftModel], **kwargs): + """Merge tuners into the base model, will not unload them. + + Args: + model(`Union[PeftModel, SwiftModel]`): The model instance with tuners + """ + from .lora_layers import LoraLayer, LoRALayer + for sub_module in model.modules(): + if isinstance(sub_module, (LoraLayer, LoRALayer)): + sub_module.merge(**kwargs) + + @staticmethod + def unmerge(model: Union[PeftModel, SwiftModel], **kwargs): + """Unmerge tuners from the base model + + Args: + model(`Union[PeftModel, SwiftModel]`): The model instance with tuners + """ + from .lora_layers import LoraLayer, LoRALayer + for sub_module in model.modules(): + if isinstance(sub_module, (LoraLayer, LoRALayer)): + sub_module.unmerge(**kwargs) + + @staticmethod + def save_to_peft_format(ckpt_dir: str, output_dir: str) -> None: + """Save swift format to peft format + + Args: + ckpt_dir(`str`): Original swift output dir + output_dir(`str`): Converted peft format dir + """ + assert ckpt_dir and output_dir, 'Please pass in valid ckpt_dir and output_dir.' + assert os.path.exists(ckpt_dir), f'ckpt_dir: {ckpt_dir} must exists in local disk.' 
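+        # Extra state dicts (parameters trained outside the tuner itself) cannot be
+        # represented in the peft format, so the conversion is refused below when the
+        # checkpoint contains an extra-states directory.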
+ if os.path.exists(os.path.join(ckpt_dir, SwiftModel.EXTRA_STATE_DIR)): + raise AssertionError('Cannot transfer to peft format, because you are additional state dicts.') + + adapter_names = [ + sub_dir for sub_dir in os.listdir(ckpt_dir) if os.path.isfile(os.path.join(ckpt_dir, sub_dir, CONFIG_NAME)) + ] + + def has_custom_content(_json): + if _json.get('swift_type', _json.get('peft_type')) != SwiftTuners.LORA: + logger.warn('Only LoRA can be converted to peft format') + return True + + from swift import LoRAConfig + return not LoRAConfig(**_json).can_be_saved_to_peft() + + for adapter in adapter_names: + with open(os.path.join(ckpt_dir, adapter, CONFIG_NAME), encoding='utf-8') as f: + _json = json.load(f) + if has_custom_content(_json): + raise AssertionError('Cannot transfer to peft format, ' + 'because you have special parameters or adapter types.') + + os.makedirs(output_dir, exist_ok=True) + if ckpt_dir != output_dir: + shutil.copytree(ckpt_dir, output_dir, dirs_exist_ok=True) + + for adapter in adapter_names: + safe_serialization = os.path.isfile(os.path.join(output_dir, adapter, SAFETENSORS_WEIGHTS_NAME)) + state_dict = SwiftModel.load_state_file(os.path.join(output_dir, adapter)) + new_state_dict = {} + for key, value in state_dict.items(): + if not key.startswith('base_model.model.'): + key = 'base_model.model.' + key + key = key.replace(f'lora_A.{adapter}.', 'lora_A.') + key = key.replace(f'lora_B.{adapter}.', 'lora_B.') + key = key.replace(f'lora_embedding_A.{adapter}.', 'lora_embedding_A.') + key = key.replace(f'lora_embedding_B.{adapter}.', 'lora_embedding_B.') + key = key.replace(f'lora_magnitude_vector.{adapter}', 'lora_magnitude_vector') + new_state_dict[key] = value + state_dict = new_state_dict + SwiftModel._save_state_dict(state_dict, os.path.join(output_dir, adapter), safe_serialization) + from swift import LoRAConfig + with open(os.path.join(output_dir, adapter, CONFIG_NAME), encoding='utf-8') as f: + _json = json.load(f) + peft_config = LoRAConfig(**_json).to_peft_config() + peft_config.save_pretrained(os.path.join(output_dir, adapter)) + + if 'default' in adapter_names: + shutil.move(os.path.join(output_dir, 'default', CONFIG_NAME), os.path.join(output_dir, CONFIG_NAME)) + state_dict = SwiftModel.load_state_file(os.path.join(output_dir, 'default')) + safe_serialization = os.path.isfile(os.path.join(output_dir, 'default', SAFETENSORS_WEIGHTS_NAME)) + SwiftModel._save_state_dict(state_dict, output_dir, safe_serialization) + shutil.rmtree(os.path.join(output_dir, 'default')) + + @staticmethod + def from_pretrained(model: Union[nn.Module, SwiftModel, PeftModel], + model_id: str = None, + adapter_name: Union[str, List[str], Dict[str, str]] = None, + revision: str = None, + **kwargs): + """Prepare a model by a model_id in the ModelScope hub or a local dir. + + Args: + model(`Union[nn.Module, 'SwiftModel']`): The model to be tuned. + model_id(`str`): The model id of the modelhub or a local dir containing the configs/weights. + adapter_name(`str`, `optional`): The adapter_name to use. + revision(`str`, `optional`): The model revision if the model_id is a model id of the modelhub. + **kwargs: + Extra kwargs needed by ``SwiftModel.from_pretrained`` or ``PeftModel.from_pretrained``. + Returns: + The model wrapped by SwiftModel or PeftModel. 
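+
+        Example (a sketch; the model id is a placeholder)::
+
+            # works for both swift-format and peft-format checkpoints: the saved config
+            # is inspected and loading is dispatched to SwiftModel or PeftModel
+            model = Swift.from_pretrained(base_model, 'my-org/my-tuned-adapter')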
+ """ + if not os.path.exists(model_id): + model_id = snapshot_download(model_id, revision=revision) + is_peft_model = False + if os.path.exists(os.path.join(model_id, CONFIG_NAME)): + with open(os.path.join(model_id, CONFIG_NAME), 'r', encoding='utf-8') as f: + _json = json.load(f) + is_peft_model = SWIFT_TYPE_KEY not in _json + + _name = adapter_name if isinstance( + adapter_name, str) or adapter_name is None else adapter_name[0] \ + if isinstance(adapter_name, list) else list(adapter_name.keys())[0] + _name = _name or '' + if os.path.exists(os.path.join(model_id, _name, CONFIG_NAME)): + with open(os.path.join(model_id, _name, CONFIG_NAME), 'r', encoding='utf-8') as f: + _json = json.load(f) + is_peft_model = SWIFT_TYPE_KEY not in _json and 'extra_state_keys' not in _json + if is_peft_model: + + def load_peft_model(_model, _adapter_name, _new_name=None): + if not _new_name: + _new_name = _adapter_name + import peft + if not isinstance(_model, peft.PeftModel): + return PeftModel.from_pretrained( + _model, + os.path.join(model_id, _adapter_name) if _adapter_name != 'default' + and os.path.exists(os.path.join(model_id, _adapter_name)) else model_id, + revision=revision, + adapter_name=_new_name, + **kwargs) + else: + _model.load_adapter( + os.path.join(model_id, _adapter_name) if _adapter_name != 'default' + and os.path.exists(os.path.join(model_id, _adapter_name)) else model_id, _new_name) + return _model + + if not adapter_name: + peft_model = load_peft_model(model, 'default') + for _dir in os.listdir(model_id): + if os.path.isdir(os.path.join(model_id, _dir)) and \ + os.path.exists(os.path.join(model_id, _dir, CONFIG_NAME)): + peft_model = load_peft_model(peft_model, _dir) + elif isinstance(adapter_name, str): + return load_peft_model(model, adapter_name) + elif isinstance(adapter_name, list): + peft_model = model + for name in adapter_name: + peft_model = load_peft_model(peft_model, name) + else: + peft_model = model + for key, value in adapter_name.items(): + peft_model = load_peft_model(peft_model, key, value) + return peft_model + else: + return SwiftModel.from_pretrained(model, model_id, revision=revision, adapter_name=adapter_name, **kwargs) diff --git a/swift/tuners/llamapro.py b/swift/tuners/llamapro.py new file mode 100644 index 0000000000000000000000000000000000000000..8ec6d254fd743750d1e7914d00a08e6ea5fc63be --- /dev/null +++ b/swift/tuners/llamapro.py @@ -0,0 +1,233 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from copy import deepcopy +from dataclasses import dataclass, field, fields +from typing import Optional + +import torch +from torch import nn + +from swift.llm import MODEL_ARCH_MAPPING, HfConfigFactory, ModelKeys +from swift.utils.logger import get_logger +from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput + +logger = get_logger() + + +@dataclass +class LLaMAProConfig(SwiftConfig): + """ + The configuration class for the LLaMAPro module. + + See https://arxiv.org/abs/2401.02415 + + Args: + model_type(`str`): LLaMAPro only support parts of the LLM models because of the variables need to be manually + modified. + num_new_blocks(`int`): How many new blocks need to be added + num_groups(`int`): The groups of new blocks are split to. Default equals to `num_new_blocks` which means each + single layer will be inserted into every `num_hidden_layers/num_new_blocks` original layers. 
+ """ + model_type: str = field( + default=None, metadata={ + 'choices': list(MODEL_ARCH_MAPPING.keys()), + }) + + num_new_blocks: int = None + + num_groups: Optional[int] = None + + def __post_init__(self): + from .mapping import SwiftTuners + self.swift_type = SwiftTuners.LLAMAPRO + + +class LLaMAPro(SwiftAdapter): + + @staticmethod + def prepare_model(model: nn.Module, config: LLaMAProConfig, adapter_name: str) -> SwiftOutput: + """Prepare a model with `LLaMAProConfig`""" + num_hidden_layers = HfConfigFactory.get_config_attr(model.config, 'num_hidden_layers') + if num_hidden_layers is None: + num_hidden_layers = HfConfigFactory.get_config_attr(model.config, 'num_layers') + assert num_hidden_layers is not None, 'Cannot find num of layers config' + assert num_hidden_layers % config.num_new_blocks == 0, f'Model layers {num_hidden_layers} ' \ + f'should be divided by {config.num_new_blocks}' + if config.num_groups is None: + config.num_groups = config.num_new_blocks + + # the except block will change the model_type, this will cause `model not found` error + # when using internvl + origin_model_type = config.model_type + model_type = origin_model_type + num_stride = num_hidden_layers // config.num_groups + try: + module_list = LLaMAPro._find_module_list(config, model) + except AssertionError as e: + model_type = LLaMAPro.search_correct_model_type(model) + if model_type is None: + language_model_name = SwiftAdapter.get_model_key_mapping(config.model_type, config).language_model + if language_model_name: + if isinstance(language_model_name, str): + language_model_name = [language_model_name] + language_model = model.get_submodule(language_model_name[0]) + model_type = LLaMAPro.search_correct_model_type(language_model) + if model_type: + model = language_model + + if model_type: + config.model_type = model_type + module_list = LLaMAPro._find_module_list(config, model) + else: + raise e + + new_module_list = nn.ModuleList() + new_module_idx = [] + for idx, module in enumerate(module_list): + new_module_list.append(module) + if (idx + 1) % num_stride == 0: + new_module = deepcopy(module) + ActivationMixin.mark_all_sub_modules_as_plugin(new_module) + new_module_list.append(new_module) + new_module_idx.append(idx + 1 + len(new_module_idx)) + + LLaMAPro._update_module_weight(config, new_module_list, new_module_idx) + LLaMAPro._update_module_attr(config, new_module_list) + model.config.num_hidden_layers = len(new_module_list) + LLaMAPro._set_module_list(config, model, new_module_list) + + def activate_module(activate: bool): + if activate: + LLaMAPro._update_module_attr(config, new_module_list) + LLaMAPro._set_module_list(config, model, new_module_list) + else: + LLaMAPro._update_module_attr(config, module_list) + LLaMAPro._set_module_list(config, model, module_list) + + def state_dict_callback(state_dict, adapter_name, **kwargs): + model_key_mapping = LLaMAPro.get_model_key_mapping(model_type, config) + new_module_list = [model_key_mapping.module_list + f'.{i}' for i in new_module_idx] + return { + key: value + for key, value in state_dict.items() if any([m_part in key for m_part in new_module_list]) + } + + def mark_trainable_callback(model): + model_key_mapping = LLaMAPro.get_model_key_mapping(model_type, config) + new_module_list = [model_key_mapping.module_list + f'.{i}' for i in new_module_idx] + for name, parameter in model.named_parameters(): + parameter: nn.Parameter + if any([m_part in name for m_part in new_module_list]): + parameter.requires_grad = True + + config.model_type = 
origin_model_type + model.activate_module = activate_module + return SwiftOutput( + config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback) + + @staticmethod + def _update_module_attr(config: LLaMAProConfig, module_list): + model_type = config.model_type + model_key_mapping = LLaMAPro.get_model_key_mapping(model_type, config) + attention = model_key_mapping.attention + attention = attention.split('{}.')[1] + if model_type == 'phi3-small': + raise ValueError('phi3-small does not support llamapro currently') + if model_type in ('llama', 'mistral', 'qwen2', 'yi', 'gemma', 'deepseek', 'openbuddy', 'xverse', 'orion', + 'bluelm', 'ziya', 'skywork', 'deepseek-v2', 'minicpm', 'phi3', 'internlm2'): + for idx, module in enumerate(module_list): + try: + getattr(module, attention).layer_idx = idx + except AttributeError: + getattr(module, 'cross_attn').layer_idx = idx + elif model_type in ('chatglm', 'glm4'): + for idx, module in enumerate(module_list): + getattr(module, attention).layer_number = idx + elif model_type in ('phi2', ): + for idx, module in enumerate(module_list): + getattr(module, attention).block_idx = idx + else: + for idx, module in enumerate(module_list): + attrs = [ + attr for attr in dir(getattr(module_list[0], attention)) + if attr in ('layer_idx', 'layer_number', 'block_idx') + ] + assert len(attrs) <= 1 + if attrs: + setattr(getattr(module, attention), attrs[0], idx) + else: + logger.warn(f'model_type: {model_type} seems has no layer_idx, if you encountered anything wrong,' + f'please give us a feedback.') + + @classmethod + def get_model_key_mapping(cls, model_type, config) -> ModelKeys: + + model_key_mapping = SwiftAdapter.get_model_key_mapping(model_type, config) + assert model_key_mapping.o_proj is not None and model_key_mapping.down_proj is not None, \ + 'LLaMAPro only support models with o_proj and down_proj components.' 
+ return model_key_mapping + + @classmethod + def search_correct_model_type(cls, module: nn.Module): + for arch_name, arch_type in MODEL_ARCH_MAPPING.items(): + arch_type: ModelKeys + if getattr(arch_type, 'module_list') is None: + # Need to be a LLM arch + continue + + matched = True + for f in fields(arch_type): + arch_str = getattr(arch_type, f.name) + if f.name == 'arch_name' or arch_str is None: + continue + + arch_str = arch_str.replace('{}', '0') + try: + sub_module = module.get_submodule(arch_str) + if sub_module is None: + matched = False + except AttributeError: + matched = False + + if not matched: + break + + if matched: + return arch_name + + @staticmethod + def _update_module_weight(config: LLaMAProConfig, module_list, new_module_idx): + model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config) + o_proj = model_key_mapping.o_proj.split('{}.')[1] + down_proj = model_key_mapping.down_proj.split('{}.')[1] + + for idx, module in enumerate(module_list): + if idx not in new_module_idx: + continue + _o_proj: nn.Linear = module.get_submodule(o_proj) + _down_proj: nn.Linear = module.get_submodule(down_proj) + _o_proj.weight.data = torch.zeros_like(_o_proj.weight.data) + _down_proj.weight.data = torch.zeros_like(_down_proj.weight.data) + if hasattr(_o_proj, 'bias') and _o_proj.bias is not None: + _o_proj.bias.data = torch.zeros_like(_o_proj.bias) + if hasattr(_down_proj, 'bias') and _down_proj.bias is not None: + _down_proj.bias.data = torch.zeros_like(_down_proj.bias) + + @staticmethod + def _set_module_list(config, module: nn.Module, module_list: nn.ModuleList): + model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config) + idx = model_key_mapping.module_list.rfind('.') + parent = module.get_submodule(model_key_mapping.module_list[:idx]) + setattr(parent, model_key_mapping.module_list[idx + 1:], module_list) + + @staticmethod + def _find_module_list(config, module: nn.Module) -> nn.ModuleList: + model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config) + return module.get_submodule(model_key_mapping.module_list) + + @staticmethod + def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None): + module.activate_module(activate) + + @staticmethod + def has_additional_modules(): + return True diff --git a/swift/tuners/longlora/__init__.py b/swift/tuners/longlora/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b937315b6e719ae8289fee2908aa486222eb76c5 --- /dev/null +++ b/swift/tuners/longlora/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
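The interplay between `_update_module_weight` and the block-expansion loop above is what makes LLaMA-Pro training-safe: every copied decoder layer has its attention output projection (`o_proj`) and MLP output projection (`down_proj`) zeroed, and since a Llama-style layer only adds those two projections onto the residual stream, the copied layer computes an exact identity mapping at initialization, so the expanded model starts out equivalent to the original one. Below is a minimal sketch of that property with a toy residual block; `ToyBlock` and its layout are illustrative stand-ins, not the actual transformers decoder layer.

import torch
import torch.nn as nn
from copy import deepcopy

class ToyBlock(nn.Module):
    """Toy stand-in for a decoder layer: two residual branches with output projections."""
    def __init__(self, dim=8):
        super().__init__()
        self.qkv = nn.Linear(dim, dim)
        self.o_proj = nn.Linear(dim, dim)      # attention output projection
        self.up_proj = nn.Linear(dim, dim)
        self.down_proj = nn.Linear(dim, dim)   # MLP output projection

    def forward(self, x):
        x = x + self.o_proj(torch.tanh(self.qkv(x)))          # residual branch 1
        x = x + self.down_proj(torch.relu(self.up_proj(x)))   # residual branch 2
        return x

block = ToyBlock()
new_block = deepcopy(block)
# LLaMA-Pro style init: zero only the output projections of the copied block.
nn.init.zeros_(new_block.o_proj.weight)
nn.init.zeros_(new_block.o_proj.bias)
nn.init.zeros_(new_block.down_proj.weight)
nn.init.zeros_(new_block.down_proj.bias)

x = torch.randn(2, 5, 8)
assert torch.allclose(new_block(x), x)  # the copied block is an identity at init

Because only the inserted blocks are marked trainable (see `mark_trainable_callback`), training starts from the identity copies while the original layers stay frozen, which is the core idea of block expansion; one copy is inserted after every `num_hidden_layers // num_groups` original layers.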
diff --git a/swift/tuners/longlora/__pycache__/__init__.cpython-310.pyc b/swift/tuners/longlora/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cb20ea79ddff3447945e8f58d2d2ec5b394fcf6 Binary files /dev/null and b/swift/tuners/longlora/__pycache__/__init__.cpython-310.pyc differ diff --git a/swift/tuners/longlora/__pycache__/longlora.cpython-310.pyc b/swift/tuners/longlora/__pycache__/longlora.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d22a0f14ce27e7780dc8dcda96504d484a60b41 Binary files /dev/null and b/swift/tuners/longlora/__pycache__/longlora.cpython-310.pyc differ diff --git a/swift/tuners/longlora/llama.py b/swift/tuners/longlora/llama.py new file mode 100644 index 0000000000000000000000000000000000000000..6c54abcc05c1b4a1d3c998cd9a1ed365ea08486f --- /dev/null +++ b/swift/tuners/longlora/llama.py @@ -0,0 +1,409 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Part of the implementation is borrowed from dvlab-research/LongLoRA. + +import math +from types import MethodType +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import Cache, StaticCache +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv + +from swift.utils import get_logger + +logger = get_logger() + + +def _preprocess_qkv_fa2(attn_module, query_states, key_states, value_states, attention_mask): + if attn_module.training: + bsz, q_len = query_states.shape[:2] + group_size = int(q_len * attn_module.config.group_size_ratio) + if q_len % group_size != 0: + raise ValueError(f'The sequence length {q_len} should' + f'be able to be split by the group_ratio {attn_module.config.group_size_ratio}') + + num_group = q_len // group_size + + def shift(qkv, bsz, q_len, group_size, num_heads, head_dim): + qkv[:, :, num_heads // 2:] = qkv[:, :, num_heads // 2:].roll(-group_size // 2, dims=1) + qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim) + return qkv + + query_states = shift(query_states, bsz, q_len, group_size, attn_module.num_heads, attn_module.head_dim) + key_states = shift(key_states, bsz, q_len, group_size, attn_module.num_heads, attn_module.head_dim) + value_states = shift(value_states, bsz, q_len, group_size, attn_module.num_heads, attn_module.head_dim) + if attention_mask is not None: + attention_mask = attention_mask[:, :group_size].repeat(num_group, 1) + + return query_states, key_states, value_states, attention_mask + + +def _preprocess_qkv(attn_module, query_states, key_states, value_states, attention_mask): + if attn_module.training: + bsz, _, q_len = query_states.shape[:3] + group_size = int(q_len * attn_module.config.group_size_ratio) + if q_len % group_size != 0: + raise ValueError(f'The sequence length {q_len} should' + f'be able to be split by the group_ratio {attn_module.config.group_size_ratio}') + + num_group = q_len // group_size + + def shift(qkv, bsz, q_len, group_size, num_heads, head_dim): + qkv[:, num_heads // 2:] = qkv[:, num_heads // 2:].roll(-group_size // 2, dims=2) + qkv = qkv.transpose(1, 2) + qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim) + return qkv.transpose(1, 2) + + query_states = shift(query_states, bsz, q_len, group_size, attn_module.num_heads, attn_module.head_dim) + key_states = shift(key_states, bsz, q_len, group_size, attn_module.num_heads, attn_module.head_dim) + value_states = shift(value_states, bsz, q_len, group_size, 
attn_module.num_heads, attn_module.head_dim) + if attention_mask is not None: + attention_mask = attention_mask[:, :, :group_size, :group_size].repeat(num_group, 1, 1, 1) + + return query_states, key_states, value_states, attention_mask + + +def _postprocess_qkv(attn_module, attn_output, q_len): + if attn_module.training: + group_size = int(q_len * attn_module.config.group_size_ratio) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(-1, q_len, attn_module.num_heads, attn_module.head_dim) + # shift back + attn_output_clone = attn_output.clone() + attn_output_clone[:, :, attn_module.num_heads // 2:] = attn_output[:, :, attn_module.num_heads // 2:].roll( + group_size // 2, dims=1) + attn_output = attn_output_clone + return attn_output.transpose(1, 2) + + +def _postprocess_qkv_fa2(attn_module, attn_output, q_len): + if attn_module.training: + group_size = int(q_len * attn_module.config.group_size_ratio) + attn_output = attn_output.reshape(-1, q_len, attn_module.num_heads, attn_module.head_dim) + attn_output_clone = attn_output.clone() + # shift back + attn_output_clone[:, :, attn_module.num_heads // 2:] = attn_output[:, :, attn_module.num_heads // 2:].roll( + group_size // 2, dims=1) + attn_output = attn_output_clone + return attn_output + + +# code borrowed from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py # noqa +def eager_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + 'The attention layers in this model are transitioning from computing the RoPE embeddings internally ' + 'through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed ' + '`position_embeddings` (Tuple of 
tensors, containing cos and sin). In v4.46 `position_ids` will be ' + 'removed and `position_embeddings` will be mandatory.') + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {'sin': sin, 'cos': cos, 'cache_position': cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # patch position rolling + query_states, key_states, value_states, causal_mask = _preprocess_qkv(self, query_states, key_states, value_states, + attention_mask) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, :key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError(f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is' + f' {attn_output.size()}') + + # patch position unrolling + attn_output = _postprocess_qkv(self, attn_output, q_len) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# code borrowed from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py # noqa +def fa2_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + '`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` ' + 'make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers' + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires 
the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + 'The attention layers in this model are transitioning from computing the RoPE embeddings internally ' + 'through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed ' + '`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be ' + 'removed and `position_embeddings` will be mandatory.') + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {'sin': sin, 'cos': cos, 'cache_position': cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout + # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, '_pre_quantization_dtype'): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f'The input hidden states seems to be silently casted in float32, this might be related to' + f' the fact you have upcasted embedding or layer norm layers in float32. 
We will cast back the input in' + f' {target_dtype}.') + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # patch position rolling + query_states, key_states, value_states, attention_mask = _preprocess_qkv_fa2( + self, query_states, key_states, value_states, attention_mask) + from transformers.modeling_flash_attention_utils import _flash_attention_forward + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, 'sliding_window', None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + # patch position unrolling + attn_output = _postprocess_qkv_fa2(self, attn_output, q_len) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# code borrowed from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py # noqa +def sdpa_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + 'The attention layers in this model are transitioning from computing the RoPE embeddings internally ' + 'through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed ' + '`position_embeddings` (Tuple of tensors, containing cos and sin). 
In v4.46 `position_ids` will be ' + 'removed and `position_embeddings` will be mandatory.') + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {'sin': sin, 'cos': cos, 'cache_position': cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, :key_states.shape[-2]] + + if query_states.device.type == 'cuda' and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + is_causal = True if causal_mask is None and q_len > 1 else False + + # patch position rolling + query_states, key_states, value_states, causal_mask = _preprocess_qkv(self, query_states, key_states, value_states, + causal_mask) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + # patch position unrolling + attn_output = _postprocess_qkv(self, attn_output, q_len) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +def replace_llama_attn(model: nn.Module): + layers = None + for module in model.modules(): + if isinstance(module, torch.nn.ModuleList): + layers = module + break + assert layers is not None + for idx, m in enumerate(layers): + if model.config._attn_implementation == 'flash_attention_2': + cuda_major, cuda_minor = torch.cuda.get_device_capability() + if cuda_major < 8: + logger.warn( + 'Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward.' # noqa + 'ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593') + m.self_attn.forward = MethodType(fa2_forward, m.self_attn) + elif model.config._attn_implementation == 'eager': + m.self_attn.forward = MethodType(eager_forward, m.self_attn) + elif model.config._attn_implementation == 'sdpa': + m.self_attn.forward = MethodType(sdpa_forward, m.self_attn) diff --git a/swift/tuners/longlora/longlora.py b/swift/tuners/longlora/longlora.py new file mode 100644 index 0000000000000000000000000000000000000000..427837b6eef17ad16c76e638d9fbc513baf2d6da --- /dev/null +++ b/swift/tuners/longlora/longlora.py @@ -0,0 +1,87 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Part of the implementation is borrowed from dvlab-research/LongLoRA. +import re +from dataclasses import dataclass, field +from typing import List, Tuple, Union + +import torch.nn as nn + +from swift.tuners.lora import lora_state_dict, mark_lora_as_trainable +from swift.tuners.lora_layers import LoraModel +from .. import LoRA, LoRAConfig, SwiftOutput + + +class LongLoRAModelType: + LLAMA = 'llama' + + +@dataclass +class LongLoRAConfig(LoRAConfig): + """ + The Config for the LongLoRA adapter. 
+ LongLoRA:[Efficient Fine-tuning of Long-Context Large Language Models](https://arxiv.org/abs/2309.12307) + This adapter uses S2-attention to shorten the attention window for long context training scenarios. + Args: + embedder_and_normalizer: LongLoRA allows the embedder and normalizer to be trainable, this parameter specifies + the names of the embedders and normalizers. + model_type: The model type, now support llama only + group_size_ratio: The group size window ratio of the sequence length. + Note: The sequence length should be split to smaller sequences by the ratio. + """ + + embedder_and_normalizer: Union[str, List[str], Tuple[str]] = field( + default=('embed', 'norm'), + metadata={ + 'help': 'The names of embedder and normalizer, regex format if is a str, else will match with sub sequences' + }) + + model_type: str = field(default=None, metadata={'help': 'The model type, now only support `llama` structure.'}) + + group_size_ratio: float = field(default=0.25, metadata={'help': 'The S2 attention group ratio'}) + + def __post_init__(self): + from swift.tuners.mapping import SwiftTuners + self.swift_type = SwiftTuners.LONGLORA + + +class LongLoRA(LoRA): + + @staticmethod + def prepare_model(model: nn.Module, config: LongLoRAConfig, adapter_name: str): + """Prepare a model with `LongLoRAConfig`""" + LoraModel(model, config, adapter_name) + + def state_dict_callback(state_dict, adapter_name, **kwargs): + _state_dict = lora_state_dict(state_dict, adapter_name, config.bias) + for name, value in state_dict.items(): + if isinstance(config.embedder_and_normalizer, str): + target_module_found = re.fullmatch(config.embedder_and_normalizer, name) + else: + target_module_found = any(target_key in name for target_key in config.embedder_and_normalizer) + if target_module_found and name not in _state_dict: # noqa + _state_dict[name] = value + return _state_dict + + def mark_trainable_callback(model): + mark_lora_as_trainable(model, adapter_name, config.bias) + mark_embedding_normalizer_as_trainable(model, config.embedder_and_normalizer) + + if config.model_type == LongLoRAModelType.LLAMA: + from .llama import replace_llama_attn + replace_llama_attn(model) + # only support code base from transformers + model.config.group_size_ratio = config.group_size_ratio + + return SwiftOutput( + config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback) + + +def mark_embedding_normalizer_as_trainable(model: nn.Module, extra_parameters: Union[str, List[str], + Tuple[str]]) -> None: + for name, sub_module in model.named_parameters(): + if isinstance(extra_parameters, str): + target_module_found = re.fullmatch(extra_parameters, name) + else: + target_module_found = any(target_key in name for target_key in extra_parameters) + if target_module_found: # noqa + sub_module.requires_grad = True diff --git a/swift/tuners/lora.py b/swift/tuners/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..b36e5df392d41c24a2e99f426a062ef018412dec --- /dev/null +++ b/swift/tuners/lora.py @@ -0,0 +1,193 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
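The `_preprocess_qkv*` helpers in `swift/tuners/longlora/llama.py` above implement LongLoRA's shifted sparse (S2) attention during training: the sequence is cut into `num_group` chunks of `group_size` tokens that are folded into the batch dimension, and half of the attention heads are first rolled by `group_size // 2` tokens so information can still cross chunk boundaries; `_postprocess_qkv*` rolls those heads back after attention. The following is a minimal sketch of the shift in the flash-attention layout `[bsz, q_len, num_heads, head_dim]`, using toy shapes and random tensors rather than real model states.

import torch

bsz, q_len, num_heads, head_dim = 2, 16, 4, 8
group_size_ratio = 0.25                        # toy value; the config default is also 0.25
group_size = int(q_len * group_size_ratio)     # 4 tokens per group
num_group = q_len // group_size                # 4 groups

def shift(qkv):
    # Roll the second half of the heads by half a group so neighbouring groups
    # overlap, then fold the groups into the batch dimension.
    qkv = qkv.clone()
    qkv[:, :, num_heads // 2:] = qkv[:, :, num_heads // 2:].roll(-group_size // 2, dims=1)
    return qkv.reshape(bsz * num_group, group_size, num_heads, head_dim)

q = torch.randn(bsz, q_len, num_heads, head_dim)
q_shifted = shift(q)
print(q_shifted.shape)  # torch.Size([8, 4, 4, 8]): attention now runs within 4-token groups

After the reshape, each group attends only within its own `group_size` tokens, which is why the training sequence length must be divisible by `group_size` (the `ValueError` raised above) and why the attention mask is sliced to the group window and repeated per group.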
+from dataclasses import asdict, dataclass, field +from functools import reduce + +import peft +import torch +from packaging import version +from transformers import Trainer + +from .lora_layers import * # noqa +from .utils import SwiftAdapter, SwiftConfig, SwiftOutput, set_adapter + +logger = get_logger() + + +@dataclass +class LoRAConfig(LoraConfig, SwiftConfig): + """ + The configuration class for the loRA module. + + Args: + use_qa_lora(bool): Use + QA-LoRA:[Quantization-Aware Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2309.14717) + instead of LoRA. QA-LoRA only supports AutoGPTQ quantized models. + Deprecated, do not use this argument. + lora_dtype(str): The dtype for all lora modules, supported values are `fp32`, `fp16`, `bf16`. + Default value is `None`, which means follow the dtype of original module's weight. + lorap_lr_ratio(float): The lr_ratio argument for [LoRA+](https://arxiv.org/abs/2402.12354) + """ + + use_qa_lora: bool = field( + default=False, metadata={'help': 'Use [qa-lora](https://github.com/yuhuixu1993/qa-lora) or not'}) + + use_merged_linear: bool = field(default=False, metadata={'help': 'Use merged Linear'}) + + enable_lora: List[bool] = field( + default=None, metadata={'help': 'The modules need to be turned on when using the merged linear layer'}) + + lora_dtype: Optional[str] = field( + default=None, metadata={'help': 'The lora dtype, default None means following the original layer\'s dtype'}) + + lorap_lr_ratio: float = field(default=2.0**4, metadata={'help': 'The lr ratio of lora_B in lora+'}) + + lorap_emb_lr: float = field(default=1e-6, metadata={'help': 'The lr for embedding in lora+'}) + + def __post_init__(self): + super().__post_init__() + from .mapping import SwiftTuners + self.swift_type = SwiftTuners.LORA + + def can_be_saved_to_peft(self) -> bool: + if self.use_qa_lora or self.use_merged_linear: + logger.warn('QA-LoRA and MergedLinear cannot be saved to peft format') + return False + return True + + def to_peft_config(self) -> LoraConfig: + _dict = asdict(self) + _dict.pop('use_qa_lora', None) + _dict.pop('enable_lora', None) + _dict.pop('lora_dtype', None) + _dict.pop('use_merged_linear', None) + _dict['peft_type'] = _dict['swift_type'] + _dict.pop('swift_type', None) + _dict.pop('lr_ratio', None) + _dict.pop('model_key_mapping', None) + return LoraConfig(**_dict) + + def save_pretrained(self, save_directory: str, **kwargs) -> None: + super(peft.LoraConfig, self).save_pretrained(save_directory, **kwargs) + + +class LoRA(SwiftAdapter): + + @staticmethod + def prepare_model(model: nn.Module, config: LoRAConfig, adapter_name: str): + assert not config.use_qa_lora, 'Do not use qa-lora' + if config.use_qa_lora: + auto_gptq_config = get_quantization_config(model, method='gptq') + if auto_gptq_config: + config.group_size = getattr(auto_gptq_config, 'group_size', None) + LoraModel(model, config, adapter_name) + + def state_dict_callback(state_dict, adapter_name, cfg=None, **kwargs): + return lora_state_dict(state_dict, adapter_name, cfg.bias if cfg else config.bias) + + def mark_trainable_callback(model, cfg=None): + mark_lora_as_trainable(model, adapter_name, cfg.bias if cfg else config.bias) + + def optimizer_group_callback(model, **defaults): + if config.lorap_lr_ratio is None: + return None, None + + def get_module(name): + parent_idx = 2 if 'lora' in name else 1 + module_names = name.split(sep='.')[:-parent_idx] + module = reduce(getattr, module_names, model) + return module + + all_params = set() + param_groups = { + 'groupA': 
{}, + 'groupB': {}, + 'groupB_no_decay': {}, + 'embedding': {}, + } + + decay_parameters = Trainer.get_decay_parameter_names(None, model) + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + module = get_module(name) + if isinstance(module, Embedding): + param_groups['embedding'][name] = param + elif 'lora_B' in name or param.ndim == 1: + if name in decay_parameters: + param_groups['groupB'][name] = param + else: + param_groups['groupB_no_decay'][name] = param + else: + param_groups['groupA'][name] = param + all_params.add(name) + + lr = defaults['lr'] + weight_decay = defaults.get('weight_decay', 0.0) + + param_groups = [ + { + 'params': list(param_groups['groupA'].values()), + 'weight_decay': weight_decay, + 'lr': lr, + }, + { + 'params': list(param_groups['embedding'].values()), + 'weight_decay': weight_decay, + 'lr': config.lorap_emb_lr, + }, + { + 'params': list(param_groups['groupB'].values()), + 'weight_decay': weight_decay, + 'lr': lr * config.lorap_lr_ratio, + }, + { + 'params': list(param_groups['groupB_no_decay'].values()), + 'weight_decay': 0.0, + 'lr': lr * config.lorap_lr_ratio, + }, + ] + return all_params, param_groups + + return SwiftOutput( + config=config, + state_dict_callback=state_dict_callback, + mark_trainable_callback=mark_trainable_callback, + optimizer_group_callback=optimizer_group_callback) + + @staticmethod + def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None): + set_adapter(module, adapter_name, activate, offload) + for sub_module in module.modules(): + if isinstance(sub_module, (LoraLayer, LoRALayer)): + sub_module.set_activation(adapter_name, activate) + if hasattr(sub_module, 'save_memory'): + sub_module.save_memory(adapter_name, activate, offload) + + @staticmethod + def unpatch_lora(model, config: LoRAConfig, adapter_name: str): + """Unpatch lora modules and merge the weights to original modules. + + LoRA constructs an additional layer with low-rank decomposition matrices of the weights in the network. + 'LoRA: Low-Rank Adaptation of Large Language Models' by Hu et al.(2021) + See https://arxiv.org/abs/2106.09685 + + Args: + model(`torch.nn.Module`): The model called with `tune` function. + config(`LoRAConfig`): The `LoRAConfig` to use. Deprecated + adapter_name(`str`): The adapter name + """ + if not config.use_merged_linear: + if version.parse(peft.__version__) < version.parse('0.6.3'): + logger.info('All adapters will be merged.') + LoraModel(model, None, '').merge_and_unload() + else: + LoraModel(model, None, '').merge_and_unload(adapter_names=[adapter_name]) + else: + for name, sub_module in model.named_modules(): + if isinstance(sub_module, MergedLinear): + sub_module.merge() + parent = model.get_submodule('.'.join(name.split('.')[:-1])) + target_name = name.split('.')[-1] + setattr(parent, target_name, sub_module.base_layer) diff --git a/swift/tuners/lora_layers.py b/swift/tuners/lora_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..f681644a3829fbbf8961fe02819a567165fbaad4 --- /dev/null +++ b/swift/tuners/lora_layers.py @@ -0,0 +1,673 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
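`optimizer_group_callback` in `LoRA.prepare_model` above implements LoRA+ by giving `lora_B` (and other 1-D parameters) a learning rate scaled by `lorap_lr_ratio`, routing embedding parameters to `lorap_emb_lr`, and dropping weight decay for the no-decay group; the returned `param_groups` are meant to be handed to the optimizer as-is. Below is a minimal sketch of how such groups are consumed, assuming AdamW and two toy LoRA factors; the tensors and values are placeholders, not the callback's real output.

import torch

# Hypothetical values mirroring the callback above.
base_lr, lorap_lr_ratio, weight_decay = 1e-4, 16.0, 0.01

lora_A = torch.nn.Parameter(torch.zeros(8, 64))
lora_B = torch.nn.Parameter(torch.zeros(64, 8))

param_groups = [
    {'params': [lora_A], 'lr': base_lr, 'weight_decay': weight_decay},
    # LoRA+: lora_B is trained with a much larger learning rate.
    {'params': [lora_B], 'lr': base_lr * lorap_lr_ratio, 'weight_decay': weight_decay},
]
optimizer = torch.optim.AdamW(param_groups)
print([g['lr'] for g in optimizer.param_groups])  # lora_B group runs at 16x the base lr

Per the LoRA+ paper cited above (arXiv:2402.12354), the asymmetric step size helps because `lora_B` is initialized to zero, so a larger learning rate on it speeds up adaptation without changing the update rule for `lora_A`.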
+import math +import re +import warnings +from itertools import chain +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from peft.import_utils import is_bnb_4bit_available, is_bnb_available +from peft.tuners.lora import Conv2d as _Conv2d +from peft.tuners.lora import Embedding as _Embedding +from peft.tuners.lora import Linear as _Linear +from peft.tuners.lora import LoraLayer +from peft.tuners.lora import LoraModel as _LoraModel +from peft.tuners.lora.tp_layer import LoraParallelLinear as _LoraParallelLinear +from peft.tuners.tuners_utils import BaseTunerLayer +from peft.utils import _get_submodules, get_quantization_config +from transformers import Conv1D + +from swift.utils import get_logger +from .peft import LoraConfig +from .utils import ActivationMixin, ModulesToSaveWrapper, SwiftAdapter + +logger = get_logger() +dispatchers = [] + + +class LoRAActivationMixin(ActivationMixin): + + @property + def active_adapters(self): + return self.get_activated_adapters() + + @property + def active_adapter(self) -> str: + return self.get_activated_adapters() + + def set_adapter(self, adapter_names, offload=None): + if isinstance(adapter_names, str): + adapter_names = [adapter_names] + + # Deactivate grads on the inactive adapter and activate grads on the active adapter + for layer_name in self.adapter_layer_names: + module_dict = getattr(self, layer_name) + for key, layer in module_dict.items(): + if key in adapter_names: + self.set_activation(key, True) + layer.requires_grad_(True) + SwiftAdapter.save_memory(layer, key, self.module_key, True) + else: + self.set_activation(key, False) + layer.requires_grad_(False) + SwiftAdapter.save_memory(layer, key, self.module_key, False, offload=offload) + + def save_memory(self, adapter_name, activate, offload=None): + for layer_name in self.adapter_layer_names: + module_dict = getattr(self, layer_name) + for key, layer in module_dict.items(): + if key == adapter_name: + if activate: + SwiftAdapter.save_memory(layer, layer_name + '.' + key, self.module_key, True) + else: + SwiftAdapter.save_memory(layer, layer_name + '.' 
+ key, self.module_key, False, offload=offload) + + def merge(self, *args, **kwargs): + if not self.unique_thread: + raise AssertionError('Merge is unsupported in multiple thread, ' + 'please set `USE_UNIQUE_THREAD=1` in env variable to merge LoRA.') + return super().merge(*args, **kwargs) + + +if is_bnb_available(): + import bitsandbytes as bnb + from peft.tuners.lora.bnb import Linear8bitLt as _Linear8bitLt + + class Linear8bitLt(LoRAActivationMixin, _Linear8bitLt): + + def __init__( + self, + *args, + module_key: str, + **kwargs, + ): + super(Linear8bitLt, self).__init__(module_key) + self.set_activation(args[1], True) + super(ActivationMixin, self).__init__(*args, **kwargs) + + def dispatch_bnb_8bit(target: torch.nn.Module, adapter_name: str, module_key: str, **kwargs): + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + loaded_in_8bit = kwargs.get('loaded_in_8bit', False) + if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt): + eightbit_kwargs = kwargs.copy() + eightbit_kwargs.update({ + 'has_fp16_weights': target.state.has_fp16_weights, + 'threshold': target.state.threshold, + 'index': target.index, + }) + new_module = Linear8bitLt(target, adapter_name, module_key=module_key, **eightbit_kwargs) + + return new_module + + dispatchers.append(dispatch_bnb_8bit) + +if is_bnb_4bit_available(): + from peft.tuners.lora.bnb import Linear4bit as _Linear4bit + + class Linear4bit(LoRAActivationMixin, _Linear4bit): + + def __init__( + self, + *args, + module_key: str, + **kwargs, + ): + super(Linear4bit, self).__init__(module_key) + self.set_activation(args[1], True) + super(ActivationMixin, self).__init__(*args, **kwargs) + + def dispatch_bnb_4bit(target: torch.nn.Module, adapter_name: str, module_key: str, **kwargs): + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + loaded_in_4bit = kwargs.get('loaded_in_4bit', False) + if loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit): + fourbit_kwargs = kwargs.copy() + fourbit_kwargs.update({ + 'compute_dtype': target_base_layer.compute_dtype, + 'compress_statistics': target_base_layer.weight.compress_statistics, + 'quant_type': target_base_layer.weight.quant_type, + }) + new_module = Linear4bit(target, adapter_name, module_key=module_key, **fourbit_kwargs) + + return new_module + + dispatchers.append(dispatch_bnb_4bit) + + +def dispatch_default( + target: torch.nn.Module, + adapter_name: str, + lora_config: LoraConfig, + module_key: str, + **kwargs, +) -> Optional[torch.nn.Module]: + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if isinstance(target_base_layer, torch.nn.Embedding): + embedding_kwargs = kwargs.copy() + embedding_kwargs.pop('fan_in_fan_out', None) + embedding_kwargs.update(lora_config.loftq_config) + new_module = Embedding(target, adapter_name, module_key=module_key, **embedding_kwargs) + elif isinstance(target_base_layer, torch.nn.Conv2d): + kwargs.update(lora_config.loftq_config) + new_module = Conv2d(target, adapter_name, module_key=module_key, **kwargs) + elif isinstance(target_base_layer, torch.nn.Linear): + if target_base_layer.__class__.__name__ == 'NonDynamicallyQuantizableLinear': + # Fix issue: https://github.com/modelscope/swift/issues/342 + return None + if 
kwargs['fan_in_fan_out']: + warnings.warn('fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. ' + 'Setting fan_in_fan_out to False.') + kwargs['fan_in_fan_out'] = lora_config.fan_in_fan_out = False + kwargs.update(lora_config.loftq_config) + new_module = Linear(target, adapter_name, module_key=module_key, **kwargs) + elif isinstance(target_base_layer, Conv1D): + if not kwargs['fan_in_fan_out']: + warnings.warn('fan_in_fan_out is set to False but the target module is `Conv1D`. ' + 'Setting fan_in_fan_out to True.') + kwargs['fan_in_fan_out'] = lora_config.fan_in_fan_out = True + kwargs.update(lora_config.loftq_config) + new_module = Linear(target, adapter_name, is_target_conv_1d_layer=True, module_key=module_key, **kwargs) + + return new_module + + +dispatchers.append(dispatch_default) + + +class Embedding(LoRAActivationMixin, _Embedding): + + def __init__( + self, + *args, + module_key: str, + **kwargs, + ) -> None: + super(Embedding, self).__init__(module_key) + self.set_activation(args[1], True) + super(ActivationMixin, self).__init__(*args, **kwargs) + + +class Linear(LoRAActivationMixin, _Linear): + + def __init__(self, *args, module_key: str, **kwargs): + super(Linear, self).__init__(module_key) + self.set_activation(args[1], True) + super(ActivationMixin, self).__init__(*args, **kwargs) + + +class Conv2d(LoRAActivationMixin, _Conv2d): + + def __init__(self, *args, module_key: str, **kwargs): + super(Conv2d, self).__init__(module_key) + self.set_activation(args[1], True) + super(ActivationMixin, self).__init__(*args, **kwargs) + + +class LoraParallelLinear(LoRAActivationMixin, _LoraParallelLinear): + + def __init__(self, *args, module_key: str, **kwargs): + super(LoraParallelLinear, self).__init__(module_key) + self.set_activation(args[1], True) + super(ActivationMixin, self).__init__(*args, **kwargs) + + +class LoraModel(_LoraModel): + + prefix: str = 'lora_' + + def __init__(self, model, config, adapter_name): + if config is not None: + super().__init__(model, config, adapter_name) + else: + nn.Module.__init__(self) + self.model = model + + def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None: + for active_adapter in self.active_adapters: + bias = self.peft_config[active_adapter].bias + if bias == 'none': + continue + + if bias == 'all': + for n, p in model.named_parameters(): + if 'bias' in n: + p.requires_grad = True + elif bias == 'lora_only': + for m in model.modules(): + if isinstance(m, LoraLayer) and hasattr(m, 'bias') and m.bias is not None: + m.bias.requires_grad = True + else: + raise NotImplementedError(f'Requested bias: {bias}, is not implemented.') + + def inject_adapter(self, + model: nn.Module, + adapter_name: str, + autocast_adapter_dtype: bool = True, + low_cpu_mem_usage: bool = False): + r""" + Override code: + 1. ModulesToSaveWrapper construction method: add module_key=key argument to offload to cpu + """ + peft_config = self.peft_config[adapter_name] + # Note: If possible, all checks should be performed *at the start of this method*. + # This way, we can raise early if something goes wrong, without leaving the model + # in a bad (half-initialized) state. 
+ self._check_new_adapter_config(peft_config) + + is_target_modules_in_base_model = False + key_list = [key for key, _ in model.named_modules()] + + _check_for_modules_to_save = getattr(peft_config, 'modules_to_save', None) is not None + _has_modules_to_save = False + + model_config = getattr(model, 'config', {'model_type': 'custom'}) + if hasattr(model_config, 'to_dict'): + model_config = model_config.to_dict() + + peft_config = self._prepare_adapter_config(peft_config, model_config) + + from peft.tuners.tuners_utils import _maybe_include_all_linear_layers + try: + from peft.utils.constants import DUMMY_TARGET_MODULES + except ImportError: # compat with peft==0.11.* + DUMMY_TARGET_MODULES = 'dummy-target-modules' + if getattr(peft_config, 'target_modules', None) == DUMMY_TARGET_MODULES: + # dummy adapter, we allow not matching any module + key_list = [] + is_target_modules_in_base_model = True + # update peft_config.target_modules if required + peft_config = _maybe_include_all_linear_layers(peft_config, model) + self._prepare_model(peft_config, model) + + for key in key_list: + if '_part_' in key or not key: + # Avoid lora conflict with part tuner + continue + # Check for modules_to_save in case + if _check_for_modules_to_save and any( + key.endswith(f'{module_to_save}') for module_to_save in peft_config.modules_to_save): + # Optionally set the modules to save + parent, target, target_name = _get_submodules(model, key) + + if not isinstance(target, ModulesToSaveWrapper): + new_module = ModulesToSaveWrapper(target, adapter_name=adapter_name, module_key=key) + setattr(parent, target_name, new_module) + else: + target.update(adapter_name) + + _has_modules_to_save = True + continue + + if not self._check_target_module_exists(peft_config, key): + continue + + self.targeted_module_names.append(key) + is_target_modules_in_base_model = True + parent, target, target_name = _get_submodules(model, key) + self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key) + + if not is_target_modules_in_base_model and hasattr(peft_config, 'target_modules'): + raise ValueError(f'Target modules {peft_config.target_modules} not found in the base model. ' + f'Please check the target modules and try again.') + + self._mark_only_adapters_as_trainable(self.model) + + if self.peft_config[adapter_name].inference_mode: + for n, p in self.model.named_parameters(): + if adapter_name in n: + p.requires_grad = False + + if _has_modules_to_save: + if not hasattr(model, 'modules_to_save'): + model.modules_to_save = set(peft_config.modules_to_save) + else: + model.modules_to_save.update(set(peft_config.modules_to_save)) + + def _convert_dtype(self, target: nn.Module, lora_dtype: str): + if lora_dtype == 'float32': + torch_dtype = torch.float32 + elif lora_dtype == 'float16': + torch_dtype = torch.float16 + elif lora_dtype == 'bfloat16': + torch_dtype = torch.bfloat16 + else: + torch_dtype = None + + if torch_dtype is not None: + if hasattr(target, 'lora_A'): + target.lora_A.to(torch_dtype) + target.lora_B.to(torch_dtype) + if hasattr(target, 'lora_embedding_A'): + target.lora_embedding_A.to(torch_dtype) + target.lora_embedding_B.to(torch_dtype) + + def _create_and_replace( + self, + lora_config, + adapter_name, + target, + target_name, + parent, + current_key, + **optional_kwargs, + ): + """ + Override code: + 1. Import bnb from upper code + 2. Support dtype converting + 3. Support skipping NonDynamicallyQuantizableLinear + 4. Add current_key argument to _create_new_module + 5. 
Use Class type defined here + 6. Allow new_module being None + """ + if current_key is None: + raise ValueError("Current Key shouldn't be `None`") + + # Regexp matching - Find key which matches current target_name in patterns provided + pattern_keys = list(chain(lora_config.rank_pattern.keys(), lora_config.alpha_pattern.keys())) + target_name_key = next(filter(lambda key: re.match(rf'.*\.{key}$', current_key), pattern_keys), current_key) + r = lora_config.rank_pattern.get(target_name_key, lora_config.r) + alpha = lora_config.alpha_pattern.get(target_name_key, lora_config.lora_alpha) + + kwargs = { + 'r': r, + 'lora_alpha': alpha, + 'lora_dropout': lora_config.lora_dropout, + 'fan_in_fan_out': lora_config.fan_in_fan_out, + 'init_lora_weights': lora_config.init_lora_weights, + 'use_rslora': lora_config.use_rslora, + 'use_dora': lora_config.use_dora, + 'loaded_in_8bit': getattr(self.model, 'is_loaded_in_8bit', False), + 'loaded_in_4bit': getattr(self.model, 'is_loaded_in_4bit', False), + } + # compat with peft==0.11.* + if hasattr(lora_config, 'runtime_config'): + kwargs['ephemeral_gpu_offload'] = lora_config.runtime_config.ephemeral_gpu_offload + + quant_methods = ['gptq', 'aqlm', 'awq'] + for quant_method in quant_methods: + quantization_config = get_quantization_config(self.model, method=quant_method) + if quantization_config is not None: + kwargs[f'{quant_method}_quantization_config'] = quantization_config + + # note: AdaLoraLayer is a subclass of LoraLayer, we need to exclude it + from peft.tuners.adalora import AdaLoraLayer + + if isinstance(target, LoraLayer) and not isinstance(target, AdaLoraLayer): + if target.__class__.__name__ == 'NonDynamicallyQuantizableLinear': + # Fix issue: https://github.com/modelscope/swift/issues/342 + return + target.update_layer( + adapter_name, + r, + lora_alpha=alpha, + lora_dropout=lora_config.lora_dropout, + init_lora_weights=lora_config.init_lora_weights, + use_rslora=lora_config.use_rslora, + use_dora=lora_config.use_dora, + ) + self._convert_dtype(target, lora_config.lora_dtype) + ActivationMixin.mark_all_sub_modules_as_plugin(target) + else: + new_module = self._create_new_module(lora_config, adapter_name, target, current_key=current_key, **kwargs) + if new_module is not None: + ActivationMixin.mark_all_sub_modules_as_plugin(new_module) + if adapter_name not in self.active_adapters: + # adding an additional adapter: it is not automatically trainable + new_module.requires_grad_(False) + self._replace_module(parent, target_name, new_module, target) + self._convert_dtype(new_module, lora_config.lora_dtype) + + def _replace_module(self, parent, child_name, new_module, child): + setattr(parent, child_name, new_module) + # It's not necessary to set requires_grad here, as that is handled by + # _mark_only_adapters_as_trainable + + # child layer wraps the original module, unpack it + if hasattr(child, 'base_layer'): + child = child.base_layer + + if not hasattr(new_module, 'base_layer'): + if hasattr(new_module, 'W_q'): # HQQ + new_module.W_q = child.W_q + else: + new_module.weight = child.weight + if hasattr(child, 'bias'): + new_module.bias = child.bias + + if getattr(child, 'state', None) is not None: + if hasattr(new_module, 'base_layer'): + new_module.base_layer.state = child.state + else: + new_module.state = child.state + new_module.to(child.weight.device) + + meta = torch.device('meta') + # dispatch to correct device + for name, module in new_module.named_modules(): + if (self.prefix in name) or ('ranknum' in name): + weight = ( + child.qweight if 
hasattr(child, 'qweight') else child.W_q if hasattr(child, 'W_q') else + child.weight if hasattr(child, 'weight') else next(child.parameters())) + if not any(p.device == meta for p in module.parameters()): + module.to(weight.device) + + @staticmethod + def _create_new_module(lora_config, adapter_name, target, **kwargs): + """ + Override code: + 1. Support current_key argument + 2. Support MergedLinear + 3. Support skipping NonDynamicallyQuantizableLinear(Move to dispatcher) + 4. Use Class type defined here(Move to dispatcher) + 5. return None instead of raising error when target type not found + """ + # Collect dispatcher functions to decide what backend to use for the replaced LoRA layer. The order matters, + # because the first match is always used. Therefore, the default layers should be checked last. + current_key = kwargs.pop('current_key') + new_module = None + if lora_config.use_qa_lora: + kwargs['use_qa_lora'] = True + kwargs['group_size'] = lora_config.group_size + if lora_config.use_merged_linear: + bias = kwargs.pop('bias', False) + new_module = MergedLinear( + adapter_name, current_key, target, bias=bias, enable_lora=lora_config.enable_lora, **kwargs) + else: + for dispatcher in dispatchers: + new_module = dispatcher(target, adapter_name, lora_config=lora_config, module_key=current_key, **kwargs) + if new_module is not None: # first match wins + break + + if new_module is None: + # no module could be matched + logger.debug( + f'Target module {target} is not supported. Currently, only the following modules are supported: ' + '`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`.') + new_module = None + + return new_module + + +class LoRALayer(ActivationMixin): + + def __init__( + self, + adapter_name: str, + module_key: str, + r: int, + lora_alpha: int, + lora_dropout: float, + merge_weights: bool, + ): + super().__init__(module_key) + self.adapter_name = adapter_name + self.r = r + self.lora_alpha = lora_alpha + # Optional dropout + if lora_dropout > 0.: + self.lora_dropout = nn.Dropout(p=lora_dropout) + else: + self.lora_dropout = lambda x: x + # Mark the weight as unmerged + self.merged = False + self.merge_weights = merge_weights + if not self._unique_thread: + self.merge_weights = False + + +class MergedLinear(nn.Linear, LoRALayer): + # LoRA implemented in a dense layer + def __init__(self, + adapter_name: str, + module_key: str, + base_layer: nn.Linear, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0., + enable_lora: List[bool] = [False], + fan_in_fan_out: bool = False, + merge_weights: bool = True, + bias: bool = True, + device=None, + dtype=None, + **kwargs): + nn.Linear.__init__(self, base_layer.in_features, base_layer.out_features, bias=bias, device=device, dtype=dtype) + LoRALayer.__init__( + self, + adapter_name, + module_key, + r=r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + merge_weights=merge_weights) + assert base_layer.out_features % len(enable_lora) == 0, \ + 'The length of enable_lora must divide out_features' + self.enable_lora = enable_lora + self.fan_in_fan_out = fan_in_fan_out + self.base_layer = base_layer + # Actual trainable parameters + if r > 0 and any(enable_lora): + self.lora_A = nn.Parameter(self.weight.new_zeros((r * sum(enable_lora), base_layer.in_features))) + self.lora_B = nn.Parameter( + self.weight.new_zeros((base_layer.out_features // len(enable_lora) * sum(enable_lora), + r))) # weights for Conv1D with groups=sum(enable_lora) + self.scaling = self.lora_alpha / self.r + # 
Freezing the pre-trained weight matrix + self.weight.requires_grad = False + # Compute the indices + self.lora_ind = self.weight.new_zeros((base_layer.out_features, ), + dtype=torch.bool).view(len(enable_lora), -1) + self.lora_ind[enable_lora, :] = True + self.lora_ind = self.lora_ind.view(-1) + self.reset_parameters() + self.weight = self.base_layer.weight + if getattr(self.base_layer, 'bias', None) is not None: + self.bias = self.base_layer.bias + if fan_in_fan_out: + self.weight.data = self.weight.data.transpose(0, 1) + + def reset_parameters(self): + nn.Linear.reset_parameters(self) + if hasattr(self, 'lora_A'): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + def zero_pad(self, x): + result = x.new_zeros((len(self.lora_ind), *x.shape[1:])) + result[self.lora_ind] = x + return result + + def merge_AB(self): + + def T(w): + return w.transpose(0, 1) if self.fan_in_fan_out else w + + delta_w = F.conv1d(self.lora_A.unsqueeze(0), self.lora_B.unsqueeze(-1), groups=sum(self.enable_lora)).squeeze(0) + return T(self.zero_pad(delta_w)) + + def merge(self, **kwargs): + if self.merge_weights and not self.merged: + # Merge the weights and mark it + if self.r > 0 and any(self.enable_lora): + self.weight.data += self.merge_AB() * self.scaling + + def unmerge(self, **kwargs): + if self.merge_weights and self.merged: + # Make sure that the weights are not merged + if self.r > 0 and any(self.enable_lora): + self.weight.data -= self.merge_AB() * self.scaling + self.merged = False + + def forward(self, x: torch.Tensor, **kwargs): + + def T(w): + return w.transpose(0, 1) if self.fan_in_fan_out else w + + if self.merged or not self.is_activated(self.adapter_name): + return F.linear(x, T(self.weight), bias=self.bias) + else: + result = F.linear(x, T(self.weight), bias=self.bias) + if self.r > 0: + x_dtype = x.dtype + x = x.to(self.lora_A.dtype) + result += self.lora_dropout(x) @ T(self.merge_AB().T) * self.scaling + result = result.to(x_dtype) + return result + + +def mark_lora_as_trainable(model: nn.Module, adapter_name: str, bias: str = 'none') -> None: + if bias == 'none': + return + elif bias == 'all': + for n, p in model.named_parameters(): + if 'bias' in n: + p.requires_grad = True + elif bias == 'lora_only': + for n, m in model.named_modules(): + if 'lora_' in n and f'.{adapter_name}' in n and \ + hasattr(m, 'bias') and \ + m.bias is not None: + m.bias.requires_grad = True + else: + raise NotImplementedError + + +def lora_state_dict(state_dict, adapter_name: str, bias: str = 'none') -> Dict[str, torch.Tensor]: + if bias == 'none': + to_return = {k: state_dict[k] for k in state_dict if 'lora_' in k} + elif bias == 'all': + to_return = {k: state_dict[k] for k in state_dict if 'lora_' in k or 'bias' in k} + elif bias == 'lora_only': + to_return = {} + for k in state_dict: + if 'lora_' in k: + to_return[k] = state_dict[k] + bias_name = k.split('lora_')[0] + 'bias' + if bias_name in state_dict: + to_return[bias_name] = state_dict[bias_name] + else: + raise NotImplementedError + return {k: v for k, v in to_return.items() if (('lora_' in k and f'.{adapter_name}' in k) or ('bias' in k))} diff --git a/swift/tuners/mapping.py b/swift/tuners/mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..aa17ef89e6af7fca7af3d53aa54958d1a4ee4f94 --- /dev/null +++ b/swift/tuners/mapping.py @@ -0,0 +1,42 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
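`MergedLinear` above keeps a single pair of low-rank factors for a fused projection (for example a packed QKV linear) and uses `enable_lora` to adapt only some of the fused outputs: `merge_AB` recovers the dense weight update with a grouped 1x1 convolution, and `zero_pad` scatters the update rows back into the positions of the enabled sub-projections. A minimal sketch of those two steps in isolation, with toy shapes (rank 2, three fused outputs of 4 rows each, `enable_lora=[True, False, True]`); it demonstrates the shape bookkeeping only, not training.

import torch
import torch.nn.functional as F

in_features, out_features, r = 8, 12, 2
enable_lora = [True, False, True]                   # adapt the 1st and 3rd fused projections
n_enabled = sum(enable_lora)                        # 2
rows_per_proj = out_features // len(enable_lora)    # 4

lora_A = torch.randn(r * n_enabled, in_features)        # (4, 8)
lora_B = torch.randn(rows_per_proj * n_enabled, r)      # (8, 2)

# Grouped 1x1 conv multiplies each enabled projection's B factor with its own A slice.
delta_w = F.conv1d(lora_A.unsqueeze(0), lora_B.unsqueeze(-1), groups=n_enabled).squeeze(0)
print(delta_w.shape)  # torch.Size([8, 8])

# Scatter the update rows back to the positions of the enabled projections.
lora_ind = torch.zeros(out_features, dtype=torch.bool).view(len(enable_lora), -1)
lora_ind[enable_lora, :] = True
lora_ind = lora_ind.view(-1)
full_delta = delta_w.new_zeros(out_features, in_features)
full_delta[lora_ind] = delta_w
print(full_delta.shape)              # torch.Size([12, 8])
print(full_delta[4:8].abs().sum())   # 0.0: the disabled middle projection is untouched

Rows belonging to the disabled middle projection stay zero, so merging `full_delta * scaling` into the base weight (as `merge` does above) leaves that sub-projection unchanged.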
+ +from .adapter import Adapter, AdapterConfig +from .llamapro import LLaMAPro, LLaMAProConfig +from .longlora.longlora import LongLoRA, LongLoRAConfig +from .lora import LoRA, LoRAConfig +from .neftune import NEFTune, NEFTuneConfig +from .part import Part, PartConfig +from .prompt import Prompt, PromptConfig +from .reft import Reft, ReftConfig +from .restuning import ResTuning, ResTuningConfig +from .scetuning.scetuning import SCETuning, SCETuningConfig +from .side import Side, SideConfig + + +class SwiftTuners: + ADAPTER = 'ADAPTER' + PROMPT = 'PROMPT' + LORA = 'LORA' + SIDE = 'SIDE' + RESTUNING = 'RESTUNING' + LONGLORA = 'longlora' + NEFTUNE = 'neftune' + LLAMAPRO = 'LLAMAPRO' + SCETUNING = 'SCETuning' + PART = 'part' + REFT = 'reft' + + +SWIFT_MAPPING = { + SwiftTuners.ADAPTER: (AdapterConfig, Adapter), + SwiftTuners.PROMPT: (PromptConfig, Prompt), + SwiftTuners.LORA: (LoRAConfig, LoRA), + SwiftTuners.SIDE: (SideConfig, Side), + SwiftTuners.RESTUNING: (ResTuningConfig, ResTuning), + SwiftTuners.LONGLORA: (LongLoRAConfig, LongLoRA), + SwiftTuners.NEFTUNE: (NEFTuneConfig, NEFTune), + SwiftTuners.SCETUNING: (SCETuningConfig, SCETuning), + SwiftTuners.LLAMAPRO: (LLaMAProConfig, LLaMAPro), + SwiftTuners.PART: (PartConfig, Part), + SwiftTuners.REFT: (ReftConfig, Reft), +} diff --git a/swift/tuners/neftune.py b/swift/tuners/neftune.py new file mode 100644 index 0000000000000000000000000000000000000000..6476283e5d2348e24823fbef0cd34abb06675308 --- /dev/null +++ b/swift/tuners/neftune.py @@ -0,0 +1,73 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from dataclasses import dataclass, field + +import torch +from torch import nn + +from swift.utils.logger import get_logger +from .utils import SwiftAdapter, SwiftConfig, SwiftOutput + +logger = get_logger() + + +@dataclass +class NEFTuneConfig(SwiftConfig): + """ + The configuration class for the NEFTune module. + + NEFTune adds slightly noises to embedding outputs. 
+ See https://arxiv.org/abs/2310.05914 + + Args: + noise_alpha(`float`): The noise alpha value used for the NEFTune, default 5.0 + """ + noise_alpha: float = field(default=5.0, metadata={'help': 'The noise alpha value used for the NEFTune'}) + + def __post_init__(self): + from .mapping import SwiftTuners + self.swift_type = SwiftTuners.NEFTUNE + + +class NEFTune(SwiftAdapter): + + @staticmethod + def prepare_model(model: nn.Module, config: NEFTuneConfig, adapter_name: str) -> SwiftOutput: + """Prepare a model with `NEFTuneConfig`""" + for sub_module in model.modules(): + if isinstance(sub_module, torch.nn.Embedding): + + def neftune_hook(module, args, output): + if module.training and getattr(module, 'nef_activated'): + dims = torch.tensor(output.size(-1) * output.size(-2)) + mag_norm = config.noise_alpha / torch.sqrt(dims) + output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm) + return output + + if hasattr(sub_module, 'nef_activated'): + raise ValueError('NEFTune does not support a second tuner.') + + sub_module.register_forward_hook(neftune_hook) + sub_module.nef_activated = True + + def state_dict_callback(state_dict, adapter_name, **kwargs): + return state_dict + + def mark_trainable_callback(model): + return + + return SwiftOutput( + config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback) + + @staticmethod + def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None): + for sub_module in module.modules(): + if isinstance(sub_module, torch.nn.Embedding): + sub_module.nef_activated = activate + + @staticmethod + def freeze_model(): + return False + + @staticmethod + def has_additional_modules(): + return False diff --git a/swift/tuners/part.py b/swift/tuners/part.py new file mode 100644 index 0000000000000000000000000000000000000000..e398986f91e3726c7da42594f598cb57dc16fc90 --- /dev/null +++ b/swift/tuners/part.py @@ -0,0 +1,119 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import re +from copy import deepcopy +from dataclasses import dataclass +from types import MethodType +from typing import Dict, Optional + +import torch +from torch import nn + +from swift.utils import get_logger +from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput + +logger = get_logger() + + +@dataclass +class PartConfig(SwiftConfig): + """ + Freeze the model and train a part of it. 
+ + Args: + target_modules(`Optional[str]`): The target modules to be trained in regex format + """ + + target_modules: Optional[str] = None + + def __post_init__(self): + from .mapping import SwiftTuners + self.swift_type = SwiftTuners.PART + + +class Part(SwiftAdapter): + + @staticmethod + def target_module_matched(module_key: str, config: PartConfig): + return re.fullmatch(config.target_modules, module_key) + + @staticmethod + def prepare_model(model: nn.Module, config: PartConfig, adapter_name: str): + name_list = [name for name, _ in model.named_modules(remove_duplicate=False)] + for name in name_list: + module: nn.Module = model.get_submodule(name) + if Part.target_module_matched(name, config) and not getattr(module, 'plugin', False): + if hasattr(module, 'base_layer'): + module = module.base_layer + + def _forward(self, *args, **kwargs): + child_list = [ + sub_module for name, sub_module in self.named_modules(remove_duplicate=False) + if '_part_' in name + ] + sub_modules = [child for child in child_list if getattr(child, 'activated', False)] + assert len(sub_modules) <= 1 + if len(sub_modules) == 1: + return sub_modules[0].forward(*args, **kwargs) + else: + return self.forward_origin(*args, **kwargs) + + if not hasattr(module, 'forward_origin'): + module.forward_origin = module.forward + module.forward = MethodType(_forward, module) + + new_module = deepcopy(module) + for attr in dir(new_module): + if '_part_' in attr: + delattr(new_module, attr) + new_module.part_name = adapter_name + ActivationMixin.mark_all_sub_modules_as_plugin(new_module) + setattr(module, f'_part_{adapter_name}', new_module) + new_module.requires_grad_(True) + + def state_dict_callback(state_dict, adapter_name, **kwargs): + new_state_dict = {} + for key, value in state_dict.items(): + if f'_part_{adapter_name}.' 
in key: + if kwargs.get('replace_key', True): + new_key = key.replace(f'_part_{adapter_name}.', '').replace('base_layer.', '') + else: + new_key = key + new_state_dict[new_key] = value + + return new_state_dict + + def mark_trainable_callback(model: nn.Module): + pass + + def load_state_dict_callback(model: nn.Module, adapter_name: str, state_dict: Dict[str, torch.Tensor]): + new_state_dict = {} + for name, module in model.named_modules(remove_duplicate=False): + module: nn.Module + if Part.target_module_matched(name, config): + for param_name in state_dict: + if param_name.startswith(name): + end = param_name[len(name):] + if '_part_' not in param_name: + if hasattr(module, 'base_layer'): + new_state_dict[name + f'.base_layer._part_{adapter_name}' + + end] = state_dict[param_name] + else: + new_state_dict[name + f'._part_{adapter_name}' + end] = state_dict[param_name] + else: + new_state_dict[param_name] = state_dict[param_name] + return new_state_dict + + return SwiftOutput( + config=config, + state_dict_callback=state_dict_callback, + mark_trainable_callback=mark_trainable_callback, + load_state_dict_callback=load_state_dict_callback) + + @staticmethod + def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None): + name_list = [name for name, _ in module.named_modules(remove_duplicate=False)] + for name in name_list: + sub_module: nn.Module = module.get_submodule(name) + if re.fullmatch(f'.*_part_{adapter_name}$', name): + sub_module.activated = activate + SwiftAdapter.save_memory(sub_module, adapter_name, name, activate, offload) diff --git a/swift/tuners/peft.py b/swift/tuners/peft.py new file mode 100644 index 0000000000000000000000000000000000000000..f561db4fc049d167f87c56bfae28b201dc967b6d --- /dev/null +++ b/swift/tuners/peft.py @@ -0,0 +1,392 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2023-present the HuggingFace Inc. team. 
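The NEFTune adapter shown earlier boils down to a single forward hook on the embedding layer. A standalone PyTorch sketch of that behaviour (toy vocabulary and hidden size; the hook logic mirrors `neftune_hook` above, without the per-adapter `nef_activated` guard) shows that the noise is applied only in training mode:

```python
import torch
from torch import nn

noise_alpha = 5.0                     # same default as NEFTuneConfig
emb = nn.Embedding(1000, 64)          # toy embedding layer

def neftune_hook(module, args, output):
    # Add uniform noise with magnitude noise_alpha / sqrt(seq_len * hidden_dim),
    # but only while the module is in training mode.
    if module.training:
        mag_norm = noise_alpha / (output.size(-1) * output.size(-2)) ** 0.5
        output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
    return output

emb.register_forward_hook(neftune_hook)

ids = torch.randint(0, 1000, (2, 16))
emb.train()
noisy = emb(ids)   # perturbed embeddings during training
emb.eval()
clean = emb(ids)   # unchanged at inference time
```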
+import os.path +from dataclasses import asdict, dataclass, field +from functools import partial, reduce +from types import MethodType +from typing import Dict, Optional + +import json +import peft +import torch +import torch.nn +import transformers +from modelscope import snapshot_download +from peft import (AdaLoraConfig, BOFTConfig, BOFTModel, LoftQConfig, LoHaConfig, LoKrConfig, LoraModel, OFTConfig, + PeftConfig, PeftModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM, + PeftModelForSequenceClassification, PeftModelForTokenClassification, PrefixTuningConfig, + PromptEncoderConfig, PromptLearningConfig, PromptTuningConfig, VeraConfig, VeraModel, get_peft_config, + get_peft_model, get_peft_model_state_dict) +from peft.config import PeftConfigMixin +from peft.tuners import lora +from peft.tuners.adalora import AdaLoraModel, RankAllocator +from peft.tuners.lora import Embedding +from transformers import Trainer + +from swift.utils import get_logger + +try: + from peft import FourierFTModel +except ImportError: + FourierFTModel = None + +try: + from peft import BoneModel +except ImportError: + BoneModel = None + +logger = get_logger() +dispatchers = [] + + +@dataclass +class LoraConfig(peft.LoraConfig): + lora_dtype: Optional[str] = field( + default=None, metadata={'help': 'The lora dtype, default None means following the original layer\'s dtype'}) + + lorap_lr_ratio: Optional[float] = field(default=None, metadata={'help': 'The lr ratio of lora_B in lora+'}) + + lorap_emb_lr: float = field(default=1e-6, metadata={'help': 'The lr for embedding in lora+'}) + + def to_peft_config(self) -> peft.LoraConfig: + _dict = asdict(self) + _dict.pop('lora_dtype') + _dict.pop('lorap_lr_ratio') + _dict.pop('lorap_emb_lr') + return peft.LoraConfig(**_dict) + + def save_pretrained(self, save_directory: str, **kwargs) -> None: + self.to_peft_config().save_pretrained(save_directory, **kwargs) + additional_args = { + 'lora_dtype': self.lora_dtype, + 'lorap_lr_ratio': self.lorap_lr_ratio, + 'lorap_emb_lr': self.lorap_emb_lr, + } + with open(os.path.join(save_directory, 'additional_config.json'), 'w', encoding='utf-8') as f: + json.dump(additional_args, f) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str, subfolder: Optional[str] = None, **kwargs): + if hasattr(PeftConfigMixin, 'from_pretrained_origin'): + self = PeftConfigMixin.from_pretrained_origin(pretrained_model_name_or_path, subfolder, **kwargs) + else: + self = super(LoraConfig, cls).from_pretrained(pretrained_model_name_or_path, subfolder, **kwargs) + + if type(self) == peft.LoraConfig: + self = LoraConfig(**self.to_dict()) + + if os.path.isfile(os.path.join(pretrained_model_name_or_path, 'additional_config.json')): + with open( + os.path.join(pretrained_model_name_or_path, 'additional_config.json'), 'r', encoding='utf-8') as f: + _json = json.load(f) + for key, value in _json.items(): + setattr(self, key, value) + + return self + + +def _create_and_replace_hook(self, peft_config, adapter_name, target, *args, **kwargs): + all_supported_names = ('linear', ) + all_supported_types = (torch.nn.Embedding, torch.nn.Conv2d, transformers.pytorch_utils.Conv1D, lora.Linear) + target_modules = getattr(peft_config, 'target_modules', None) + if target is None: + return + + if isinstance(target_modules, str) and not any( + [name in target.__class__.__name__.lower() + for name in all_supported_names]) and not any([isinstance(target, type_) for type_ in all_supported_types]): + return + + if target.__class__.__name__ == 
'NonDynamicallyQuantizableLinear': + return + + return self._create_and_replace_origin(peft_config, adapter_name, target, *args, **kwargs) + + +def _convert_dtype(target: torch.nn.Module, adapter_name: str, lora_dtype: str): + if lora_dtype is not None: + torch_dtype = eval(f'torch.{lora_dtype}') + if hasattr(target, 'lora_A') and adapter_name in target.lora_A: + target.lora_A[adapter_name].to(torch_dtype) + target.lora_B[adapter_name].to(torch_dtype) + if hasattr(target, 'lora_embedding_A') and adapter_name in target.lora_embedding_A: + target.lora_embedding_A[adapter_name].to(torch_dtype) + target.lora_embedding_B[adapter_name].to(torch_dtype) + + +def create_optimizer_param_groups(self: PeftModel, **defaults): + if not isinstance(self.peft_config[self.active_adapter], + LoraConfig) or self.peft_config[self.active_adapter].lorap_lr_ratio is None: + return None + + def get_module(name): + parent_idx = 2 if 'lora' in name else 1 + module_names = name.split(sep='.')[:-parent_idx] + module = reduce(getattr, module_names, self.base_model) + return module + + param_groups = { + 'groupA': {}, + 'groupB': {}, + 'groupB_no_decay': {}, + 'embedding': {}, + } + + decay_parameters = Trainer.get_decay_parameter_names(None, self.base_model) + for name, param in self.base_model.named_parameters(): + if not param.requires_grad: + continue + + module = get_module(name) + if isinstance(module, Embedding): + param_groups['embedding'][name] = param + elif 'lora_B' in name or param.ndim == 1: + if name in decay_parameters: + param_groups['groupB'][name] = param + else: + param_groups['groupB_no_decay'][name] = param + else: + param_groups['groupA'][name] = param + + lr = defaults['lr'] + weight_decay = defaults.get('weight_decay', 0.0) + + param_groups = [ + { + 'params': list(param_groups['groupA'].values()), + 'weight_decay': weight_decay, + 'lr': lr, + }, + { + 'params': list(param_groups['embedding'].values()), + 'weight_decay': weight_decay, + 'lr': self.peft_config[self.active_adapter].lorap_emb_lr, + }, + { + 'params': list(param_groups['groupB'].values()), + 'weight_decay': weight_decay, + 'lr': lr * self.peft_config[self.active_adapter].lorap_lr_ratio, + }, + { + 'params': list(param_groups['groupB_no_decay'].values()), + 'weight_decay': 0.0, + 'lr': lr * self.peft_config[self.active_adapter].lorap_lr_ratio, + }, + ] + return param_groups + + +def adalora_forward(self, *args, **kwargs): + from peft.utils.integrations import gather_params_ctx + outputs = self.model.forward(*args, **kwargs) + + if (getattr(outputs, 'loss', None) is not None) and isinstance(outputs.loss, torch.Tensor): + # Calculate the orthogonal regularization + orth_reg_weight = self.peft_config[self.trainable_adapter_name].orth_reg_weight + + if orth_reg_weight <= 0: + raise ValueError('orth_reg_weight should be greater than 0. 
') + + regu_loss = 0 + num_param = 0 + for n, p in self.model.named_parameters(): + if ('lora_A' in n or 'lora_B' in n) and self.trainable_adapter_name in n: + if p.shape == torch.Size([0]): + with gather_params_ctx(p, fwd_module=self): + para_cov = p @ p.T if 'lora_A' in n else p.T @ p + else: + para_cov = p @ p.T if 'lora_A' in n else p.T @ p + I = torch.eye(*para_cov.size(), out=torch.empty_like(para_cov)) # noqa: E741 + I.requires_grad = False + num_param += 1 + if isinstance(regu_loss, torch.Tensor): + regu_loss = regu_loss.to(para_cov.device) + regu_loss += torch.norm(para_cov - I, p='fro') + if num_param > 0: + regu_loss = regu_loss / num_param + else: + regu_loss = 0 + if isinstance(regu_loss, torch.Tensor) and isinstance(outputs.loss, torch.Tensor): + regu_loss = regu_loss.to(outputs.loss.device) + outputs.loss += orth_reg_weight * regu_loss + return outputs + + +def adalora_mask_to_budget(self, model, budget): + value_ipt = {} + vector_ipt = {} + triplet_ipt = {} + # Get the importance score for A, E, B + for n, p in model.named_parameters(): + if f'lora_A.{self.adapter_name}' in n: + entry_ipt = self._element_score(n) + comb_ipt = torch.mean(entry_ipt, dim=1, keepdim=True) + name_m = n.replace('lora_A', '%s') + if name_m not in vector_ipt: + vector_ipt[name_m] = [comb_ipt] + else: + vector_ipt[name_m].append(comb_ipt) + if f'lora_B.{self.adapter_name}' in n: + entry_ipt = self._element_score(n) + comb_ipt = torch.mean(entry_ipt, dim=0, keepdim=False).view(-1, 1) + name_m = n.replace('lora_B', '%s') + if name_m not in vector_ipt: + vector_ipt[name_m] = [comb_ipt] + else: + vector_ipt[name_m].append(comb_ipt) + if f'lora_E.{self.adapter_name}' in n: + entry_ipt = self._element_score(n) + name_m = n.replace('lora_E', '%s') + value_ipt[name_m] = entry_ipt + + all_score = [] + # Calculate the score for each triplet + for name_m in vector_ipt: + ipt_E = value_ipt[name_m] + ipt_AB = torch.cat(vector_ipt[name_m], dim=1) + sum_ipt = self._combine_ipt(ipt_E, ipt_AB) + name_E = name_m % 'lora_E' + triplet_ipt[name_E] = sum_ipt.view(-1, 1) + sum_ipt = sum_ipt.view(-1) + if all_score: + sum_ipt = sum_ipt.to(all_score[0].device) + all_score.append(sum_ipt) + + # Get the threshold by ranking ipt + mask_threshold = torch.kthvalue( + torch.cat(all_score), + k=self.init_bgt - budget, + )[0].item() + + rank_pattern = {} + # Mask the unimportant triplets + with torch.no_grad(): + for n, p in model.named_parameters(): + if f'lora_E.{self.adapter_name}' in n: + p.masked_fill_(triplet_ipt[n] <= mask_threshold, 0.0) + rank_pattern[n] = (~(triplet_ipt[n] <= mask_threshold)).view(-1).tolist() + return rank_pattern + + +def keep_device_forward(self, *args, **kwargs): + x = args[0] + if self.weight.device != x.device: + return self.forward_origin(x.to(self.weight.device), *args[1:], **kwargs) + else: + return self.forward_origin(*args, **kwargs) + + +def hot_patch_peft_module(): + from peft.tuners.lora import LoraLayer + if hasattr('LoraModel', '_create_and_replace_origin'): + return + + # Fix Lora does not support NonDynamicallyQuantizableLinear + LoraModel._create_and_replace_origin = LoraModel._create_and_replace + LoraModel._create_and_replace = _create_and_replace_hook + AdaLoraModel._create_and_replace_origin = AdaLoraModel._create_and_replace + AdaLoraModel._create_and_replace = _create_and_replace_hook + VeraModel._create_and_replace_origin = VeraModel._create_and_replace + VeraModel._create_and_replace = _create_and_replace_hook + BOFTModel._create_and_replace_origin = 
BOFTModel._create_and_replace + BOFTModel._create_and_replace = _create_and_replace_hook + if FourierFTModel is not None: + FourierFTModel._create_and_replace_origin = FourierFTModel._create_and_replace + FourierFTModel._create_and_replace = _create_and_replace_hook + if BoneModel is not None: + BoneModel._create_and_replace_origin = BoneModel._create_and_replace + BoneModel._create_and_replace = _create_and_replace_hook + + # Support type conversion + def __new_init__(self, model: torch.nn.Module, config: Dict[str, LoraConfig], adapter_name: str): + + self.__init_origin__(model, config, adapter_name) + active_adapters = self.active_adapter + if isinstance(active_adapters, str): + active_adapters = [active_adapters] + for active_adapter in active_adapters: + active_config = config[active_adapter] if isinstance(config, dict) else config + if hasattr(active_config, 'lora_dtype'): + for name, module in model.named_modules(): + if isinstance(module, LoraLayer): + _convert_dtype(module, active_adapter, active_config.lora_dtype) + for lora in list(module.lora_A.values()) + list(module.lora_B.values()): + if not hasattr(lora, 'forward_origin'): + lora.forward_origin = lora.forward + lora.forward = MethodType(keep_device_forward, lora) + + LoraModel.__init_origin__ = LoraModel.__init__ + LoraModel.__init__ = __new_init__ + + # Support LoRA+ + PeftModel.create_optimizer_param_groups = create_optimizer_param_groups + + PeftConfigMixin.from_pretrained_origin = PeftConfigMixin.from_pretrained + PeftConfigMixin.from_pretrained = LoraConfig.from_pretrained + + # Compatible with SwiftModel + def dummy_function(*args, **kwargs): + logger.warn(f'The function {kwargs["func"]} has no effects, consider using other functions.') + + PeftModel.activate_adapter = PeftModel.set_adapter + PeftModel.deactivate_adapter = partial(dummy_function, func='deactivate_adapter') + PeftModel.set_active_adapters = partial(dummy_function, func='set_active_adapters') + + # Fix adalora does not support device_map + AdaLoraModel.forward = adalora_forward + RankAllocator.mask_to_budget = adalora_mask_to_budget + + +def get_wrapped_class(module_class): + """Get a custom wrapper class for peft classes to download the models from the ModelScope hub + + Args: + module_class: The actual module class + + Returns: + The wrapper + """ + + class PeftWrapper(module_class): + + @classmethod + def from_pretrained(cls, model, model_id, *args, revision: Optional[str] = None, **kwargs): + if not os.path.exists(model_id): + model_id = snapshot_download(model_id, revision=revision) + return module_class.from_pretrained(model, model_id, *args, **kwargs) + + PeftWrapper.__name__ = module_class.__name__ + PeftWrapper.__qualname__ = module_class.__qualname__ + return PeftWrapper + + +def wrap_module(module): + if not hasattr(module, 'from_pretrained'): + return module + + return get_wrapped_class(module) + + +hot_patch_peft_module() +PeftModel = wrap_module(PeftModel) +PeftConfig = wrap_module(PeftConfig) +PeftModelForSeq2SeqLM = wrap_module(PeftModelForSeq2SeqLM) +PeftModelForSequenceClassification = wrap_module(PeftModelForSequenceClassification) +PeftModelForTokenClassification = wrap_module(PeftModelForTokenClassification) +PeftModelForCausalLM = wrap_module(PeftModelForCausalLM) +PromptEncoderConfig = wrap_module(PromptEncoderConfig) +PromptTuningConfig = wrap_module(PromptTuningConfig) +PrefixTuningConfig = wrap_module(PrefixTuningConfig) +PromptLearningConfig = wrap_module(PromptLearningConfig) +LoraConfig = wrap_module(LoraConfig) 
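The LoRA+ support added through `create_optimizer_param_groups` above is essentially a learning-rate split between the `lora_A` and `lora_B` matrices. A simplified sketch with a toy two-matrix module (hypothetical names and values; the real implementation additionally separates embedding parameters and no-weight-decay parameters) could look like:

```python
import torch
from torch import nn

lr, lorap_lr_ratio, weight_decay = 1e-4, 16.0, 0.1

model = nn.ModuleDict({
    'lora_A': nn.Linear(32, 4, bias=False),
    'lora_B': nn.Linear(4, 32, bias=False),
})

group_a, group_b = [], []
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    # lora_B gets the boosted learning rate, everything else keeps the base lr.
    (group_b if 'lora_B' in name else group_a).append(param)

optimizer = torch.optim.AdamW([
    {'params': group_a, 'lr': lr, 'weight_decay': weight_decay},
    {'params': group_b, 'lr': lr * lorap_lr_ratio, 'weight_decay': weight_decay},
])
```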
+AdaLoraConfig = wrap_module(AdaLoraConfig) +LoHaConfig = wrap_module(LoHaConfig) +LoKrConfig = wrap_module(LoKrConfig) +LoftQConfig = wrap_module(LoftQConfig) +OFTConfig = wrap_module(OFTConfig) +BOFTConfig = wrap_module(BOFTConfig) +VeraConfig = wrap_module(VeraConfig) +OFTConfig = wrap_module(OFTConfig) +get_peft_config = get_peft_config +get_peft_model_state_dict = get_peft_model_state_dict +get_peft_model = get_peft_model diff --git a/swift/tuners/prompt.py b/swift/tuners/prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..3b3d1ab4e80eadbd9f6fb176e672d73a316da2cb --- /dev/null +++ b/swift/tuners/prompt.py @@ -0,0 +1,205 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import re +import types +from dataclasses import dataclass, field +from typing import List, Union + +import torch +from torch import nn + +from swift.utils import get_logger +from swift.utils.torch_utils import find_sub_module +from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput + +logger = get_logger() + + +@dataclass +class PromptConfig(SwiftConfig): + """ + The configuration class for the prompt module. + + Visual prompt tuning (VPT) is proposed to initialize tunable prompt tokens + and prepend to the original tokens in the first layer or multiple layers. + 'Visual Prompt Tuning' by Jia et al.(2022) + See https://arxiv.org/abs/2203.12119 + + Here we apply the VPT to other fields. + + Args: + dim(`Union[int, List[int]]`): The dimension of the hidden states, use list if there are up-sample blocks + or down-sample blocks + target_modules(str): The layer module to be replaced, in regex format + embedding_pos(Union[str, int]): The position of the embedding tensor + attention_mask_pos(Union[str, int]): The position of the attention mask + attention_mask_value(Union[float, int, bool]): The value to pad to the attention mask + prompt_length(int): The length of the prompt tokens + attach_front(bool): When set to True, prompt is attached in front of the embedding + extract_embedding(bool): Whether the embedding is extracted at final stage to keep the same dims with inputs + """ + + dim: Union[int, List[int]] = field(default=None, metadata={'help': 'The dimension of the hidden states'}) + + target_modules: str = field(default=None, metadata={'help': 'The layer module to be replaced, in regex format'}) + + embedding_pos: Union[str, int] = field(default=None, metadata={'help': 'The position of the embedding tensor'}) + + attention_mask_pos: Union[str, int] = field(default=None, metadata={'help': 'The position of the attention mask'}) + + attention_mask_value: Union[float, int, bool] = field( + default=0., metadata={'help': 'The value to pad to the attention mask'}) + + prompt_length: int = field(default=16, metadata={'help': 'The length of the prompt tokens'}) + + attach_front: bool = field( + default=True, metadata={'help': 'When set to True, prompt is attached in front of the embedding'}) + + extract_embedding: bool = field( + default=False, + metadata={'help': 'Whether the embedding is extracted at final stage to keep the same dims with inputs'}) + + def __post_init__(self): + from .mapping import SwiftTuners + self.swift_type = SwiftTuners.PROMPT + + +class Prompt(SwiftAdapter): + + @staticmethod + def prepare_model(model: nn.Module, config: PromptConfig, adapter_name: str): + module_keys = [key for key, _ in model.named_modules()] + match_module_keys = [] + for module_key in module_keys: + if isinstance(config.target_modules, str): + target_module_found = 
re.fullmatch(config.target_modules, module_key) + else: + target_module_found = any(module_key.endswith(target_key) for target_key in config.target_modules) + if target_module_found: # noqa + module = model.get_submodule(module_key) + + def _forward(self, *args, **kwargs): + if isinstance(config.embedding_pos, int): + input_embedding = args[config.embedding_pos] + else: + input_embedding = kwargs[config.embedding_pos] + + input_embedding = getattr(self, f'prompt_{adapter_name}').forward(input_embedding) + if isinstance(config.embedding_pos, int): + args = type(args)( + args[0:config.embedding_pos] + (input_embedding, ) + args[config.embedding_pos + 1:]) + else: + kwargs[config.embedding_pos] = input_embedding + + if config.attention_mask_pos: + attention_mask = None + if isinstance(config.attention_mask_pos, int): + attention_mask = args[config.attention_mask_pos] + elif isinstance(config.attention_mask_pos, str): + attention_mask = kwargs[config.attention_mask_pos] + + if attention_mask is not None: + attention_mask = getattr(self, + f'prompt_{adapter_name}').patch_attention_mask(attention_mask) + if isinstance(config.attention_mask_pos, int): + args = type(args)( + args[0:config.attention_mask_pos] + (attention_mask, ) + + args[config.attention_mask_pos + 1:]) + else: + kwargs[config.attention_mask_pos] = attention_mask + + forward_output = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs) + if config.extract_embedding: + forward_output = getattr(self, f'prompt_{adapter_name}').extract(forward_output) + + return forward_output + + setattr(module, f'forward_origin_{adapter_name}', module.forward) + module.forward = types.MethodType(_forward, module) + if isinstance(config.dim, list): + input_dim = config.dim[len(match_module_keys)] + else: + input_dim = config.dim + prompt_module = PromptModule(input_dim, int(module_key.rsplit('.')[-1]), adapter_name, module_key, + config.prompt_length, config.attention_mask_value, config.attach_front) + setattr(module, f'prompt_{adapter_name}', prompt_module) + logger.info(f'Prompt modules(module_key): {module_key}.prompt_{adapter_name}') + match_module_keys.append(module_key) + + def state_dict_callback(state_dict, adapter_name, **kwargs): + return {key: value for key, value in state_dict.items() if f'prompt_{adapter_name}' in key} + + def mark_trainable_callback(model): + return + + return SwiftOutput( + config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback) + + @staticmethod + def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None): + modules = find_sub_module(module, f'prompt_{adapter_name}') + for _module in modules: + _module: ActivationMixin + _module: nn.Module + _module.set_activation(adapter_name, activate) + SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload) + + +class PromptModule(nn.Module, ActivationMixin): + """The implementation of vision prompt tuning method. + + Visual prompt tuning (VPT) is proposed to initialize tunable prompt tokens + and prepend to the original tokens in the first layer or multiple layers. + 'Visual Prompt Tuning' by Jia et al.(2022) + See https://arxiv.org/abs/2203.12119 + + Args: + dim: An integer indicating the embedding dimension. + layer_num: An integer indicating number of layers. + prompt_length: An integer indicating the length of vision prompt tuning. 
+ """ + + def __init__(self, dim, layer_num, adapter_name, module_key, prompt_length=None, mask_values=0., attach_front=True): + super(PromptModule, self).__init__() + super(nn.Module, self).__init__(module_key) + self.dim = dim + self.layer_num = layer_num + self.adapter_name = adapter_name + self.prompt_length = prompt_length + self.mask_values = mask_values + self.attach_front = attach_front + self.prompt_token = nn.Parameter(torch.zeros(1, prompt_length, dim)) + nn.init.xavier_uniform_(self.prompt_token) + self.mark_all_sub_modules_as_plugin() + + def forward(self, x): + if not self.is_activated(self.adapter_name): + return x + prompt_token = self.prompt_token.expand(x.shape[0], -1, -1).to(x.device, x.dtype) + + if self.layer_num == 0: + if self.attach_front: + x = torch.cat((prompt_token, x), dim=1) + else: + x = torch.cat((x, prompt_token), dim=1) + else: + if self.attach_front: + x = torch.cat((prompt_token, x[:, self.prompt_length:, :]), dim=1) + else: + x = torch.cat((x[:, :-self.prompt_length, :], prompt_token), dim=1) + return x + + def patch_attention_mask(self, m): + if not self.is_activated(self.adapter_name): + return m + prefix_attention_mask = torch.full((*m.shape[:-1], self.prompt_length), self.mask_values).to(m.device) + if self.attach_front: + return torch.cat((prefix_attention_mask, m), dim=-1) + else: + return torch.cat((m, prefix_attention_mask), dim=-1) + + def extract(self, x): + if self.attach_front: + return x[:, self.prompt_length:, :] + else: + return x[:, :-self.prompt_length, :] diff --git a/swift/tuners/reft.py b/swift/tuners/reft.py new file mode 100644 index 0000000000000000000000000000000000000000..8179b61ccda8b81241cd583ec039c70665e4077a --- /dev/null +++ b/swift/tuners/reft.py @@ -0,0 +1,215 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from dataclasses import dataclass +from types import MethodType +from typing import List, Literal, Optional + +import json +import torch +from torch import nn + +from swift.utils import get_logger, patch_getattr +from .utils import SwiftAdapter, SwiftConfig, SwiftOutput + +logger = get_logger() + + +@dataclass +class ReftConfig(SwiftConfig): + """ + Train a model with Reft. + Paper: https://arxiv.org/pdf/2404.03592 + + Args: + model_type(`Optional[str]`): The model_type to find down_proj/layers. + layer_key(`Optional[str]`): Manually specify the layer key, for example `language_model.layers`. + layers (`Optional[List[int]]`): The layer number to inject. + r(`int`): The rank of Reft. 
+ intervention_type (`Literal['NoreftIntervention', 'LoreftIntervention', + 'ConsreftIntervention', 'LobireftIntervention', + 'DireftIntervention', 'NodireftIntervention']`): The intervention type, + default LoreftIntervention + args (`Optional[str]`): Other reft_args in json-string format + """ + + model_type: Optional[str] = None + layer_key: Optional[str] = None + layers: Optional[List[int]] = None + r: int = 4 + intervention_type: Literal['NoreftIntervention', 'LoreftIntervention', 'ConsreftIntervention', + 'LobireftIntervention', 'DireftIntervention', + 'NodireftIntervention'] = 'LoreftIntervention' + args: Optional[str] = None + + def __post_init__(self): + from .mapping import SwiftTuners + self.swift_type = SwiftTuners.REFT + if self.args: + self.args = json.loads(self.args) + else: + self.args = {} + + +class Reft(SwiftAdapter): + + @staticmethod + def prepare_model(model: nn.Module, config: ReftConfig, adapter_name: str): + from swift.utils.import_utils import is_pyreft_available + if not is_pyreft_available(): + raise ImportError('Please install pyreft before using ReFT: ' '`pip install pyreft`') + + import pyreft + from pyreft import ReftModel + from pyreft.interventions import LowRankRotateLayer + from pyreft import ( + NoreftIntervention, + LoreftIntervention, + ConsreftIntervention, + LobireftIntervention, + DireftIntervention, + NodireftIntervention, + ) + + intervention_mapping = { + 'NoreftIntervention': NoreftIntervention, + 'LoreftIntervention': LoreftIntervention, + 'ConsreftIntervention': ConsreftIntervention, + 'LobireftIntervention': LobireftIntervention, + 'DireftIntervention': DireftIntervention, + 'NodireftIntervention': NodireftIntervention, + } + + patch_getattr(ReftModel, 'model') + + def forward(self, x): + self.to(x.device) + return self.forward_origin(x) + + def forward2(self, base, source=None, subspaces=None): + self.to(base.device) + return self.forward_origin(base, source, subspaces) + + if not hasattr(LowRankRotateLayer, 'forward_origin'): + LowRankRotateLayer.forward_origin = LowRankRotateLayer.forward + LowRankRotateLayer.forward = forward + NoreftIntervention.forward_origin = NoreftIntervention.forward + NoreftIntervention.forward = forward2 + LoreftIntervention.forward_origin = LoreftIntervention.forward + LoreftIntervention.forward = forward2 + ConsreftIntervention.forward_origin = ConsreftIntervention.forward + ConsreftIntervention.forward = forward2 + LobireftIntervention.forward_origin = LobireftIntervention.forward + LobireftIntervention.forward = forward2 + DireftIntervention.forward_origin = DireftIntervention.forward + DireftIntervention.forward = forward2 + NodireftIntervention.forward_origin = NodireftIntervention.forward + NodireftIntervention.forward = forward2 + + module_list_key = config.layer_key + if module_list_key is None: + model_key_mapping = Reft.get_model_key_mapping(config.model_type, config) + module_list_key = model_key_mapping.module_list + logger.info(f'Applying Reft to module: {module_list_key}') + module_list: nn.ModuleList = model.get_submodule(module_list_key) + representations = [] + for idx, layer in enumerate(module_list): + if config.layers and idx not in config.layers: + continue + intervention_config = { + 'layer': + idx, + 'component': + module_list_key + f'[{idx}].output', + 'low_rank_dimension': + config.r, + 'intervention': + intervention_mapping[config.intervention_type]( + embed_dim=model.config.hidden_size, low_rank_dimension=config.r, **config.args) + } + representations.append(intervention_config) + + 
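+        # Note: each entry in `representations` names the decoder layer to hook
+        # (via 'component'), the low-rank dimension, and a pyreft intervention
+        # instance; pyreft.ReftConfig / get_reft_model below attach these
+        # interventions to the frozen base model.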
reft_config = pyreft.ReftConfig(representations=representations) + reft_model = pyreft.get_reft_model(model, reft_config, set_device=False) + reft_model.reft_config = reft_model.config + reft_model.config = reft_model.model.config + + def _pre_forward_hook(module, args, kwargs): + if 'base' in kwargs: + return args, kwargs + + if 'input_ids' not in kwargs: + raise ValueError('Input does not contain `input_ids`, maybe the model does not support ReFT.') + # run intervened forward pass + unit_locations = None + if 'intervention_locations' in kwargs: + if kwargs['intervention_locations'].dim() == 3: + unit_locations = { + 'sources->base': (None, kwargs['intervention_locations'].permute(1, 0, 2).tolist()) + } + else: + # this is dummy for lora only baseline + unit_locations = {'sources->base': (None, 0)} + kwargs = { + 'base': { + 'input_ids': kwargs['input_ids'], + 'attention_mask': kwargs['attention_mask'] + }, + 'unit_locations': unit_locations, + 'labels': kwargs['labels'], + 'subspaces': kwargs['subspaces'].permute(1, 0, 2).tolist() if 'subspaces' in kwargs else None + } + return args, kwargs + + def _post_forward_hook(module, args, kwargs, outputs): + return outputs[1] + + def _generate(self, **kwargs): + # run intervened forward pass + unit_locations = None + if 'intervention_locations' in kwargs: + if kwargs['intervention_locations'].dim() == 3: + unit_locations = { + 'sources->base': (None, kwargs['intervention_locations'].permute(1, 0, 2).tolist()) + } + else: + # this is dummy for lora only baseline + unit_locations = {'sources->base': (None, 0)} + + _kwargs = { + 'base': { + 'input_ids': kwargs.pop('input_ids'), + 'attention_mask': kwargs.pop('attention_mask') + }, + 'unit_locations': unit_locations, + 'subspaces': kwargs.pop('subspaces').permute(1, 0, 2).tolist() if 'subspaces' in kwargs else None + } + _kwargs = {**_kwargs, **kwargs} + return self.generate_origin(**_kwargs)[1] + + reft_model.generate_origin = reft_model.generate + reft_model.generate = MethodType(_generate, reft_model) + reft_model.register_forward_pre_hook(_pre_forward_hook, with_kwargs=True) + reft_model.register_forward_hook(_post_forward_hook, with_kwargs=True) + + def save_callback(swift_model, model_dir, adapter_name): + reft_model.save_intervention(save_directory=model_dir, include_model=False) + + def mark_trainable_callback(model): + return + + def load_callback(swift_model, model_dir, adapter_name): + reft_model.load_intervention(model_dir, include_model=False) + + return SwiftOutput( + model=reft_model, + config=config, + mark_trainable_callback=mark_trainable_callback, + save_callback=save_callback, + load_callback=load_callback) + + @staticmethod + def has_additional_modules(): + return True + + @staticmethod + def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None): + assert activate, 'ReFT does not support deactivate' diff --git a/swift/tuners/restuning.py b/swift/tuners/restuning.py new file mode 100644 index 0000000000000000000000000000000000000000..7a9def230a9c2e228d306b7304c4e006680c40ad --- /dev/null +++ b/swift/tuners/restuning.py @@ -0,0 +1,327 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
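Before the Res-Tuning code below, a hedged usage sketch of the `ReftConfig` defined above. The field names come straight from the dataclass; how the config is then handed to a model (through the SWIFT_MAPPING/tuner machinery) is assumed rather than shown here.

```python
from swift.tuners.reft import ReftConfig

# Intervene on four decoder layers with rank-4 LoReFT interventions.
config = ReftConfig(
    layers=[4, 8, 12, 16],
    r=4,
    intervention_type='LoreftIntervention',   # the default; listed for clarity
)
print(config.swift_type)   # 'reft', set in __post_init__
```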
+import copy +import re +import types +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn + +from swift.utils import get_logger +from swift.utils.torch_utils import find_sub_module +from .restuning_components import ResTuner, detach_tensors, probe_input_pre_hook, probe_output_hook +from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput + +logger = get_logger() + + +@dataclass +class ResTuningConfig(SwiftConfig): + """ + The configuration class for the ResTuning module. + + ResTuning is a flexible parameter-efficient and memory-efficient tuning paradigm framework. + 'Res-Tuning: A Flexible and Efficient Tuning Paradigm via Unbinding Tuner from Backbone' + by Jiang et al.(2023) + See + + Args: + dims(`Union[List[int], int]`): The dimensions of the hidden states + root_modules(`str`): The root module to be replaced, can a regex string + root_modules_hook(`str`): The hook type of root modules, can be "input" or "output" + stem_modules(`Union[List[str], str]`): The stem modules to be replaced, + can a regex string or name list of full match format + stem_modules_hook(`Union[List[str], str]`): The hook type of stem modules, can be "input" or "output" + target_modules(`str`): The target module to be replaced, can a regex string + target_modules_hook(`str`): The hook type of target modules, can be "input" or "output" + tuner_cfg(`Union[List[Dict], Dict, str]`): The configuration of the tuning module, + can a string or customized config + use_upsample(bool): Whether to use auxiliary upsample module + upsample_out_channels(List[int]): The channels if `use_upsample` + zero_init_last(bool): Use zero to initialize the last Linear in every sub tuner. + + """ + + dims: Optional[Union[List[int], int]] = field( + default=None, metadata={'help': 'The dimensions of the hidden states'}) + + root_modules: str = field( + default=None, + metadata={ + 'help': + 'The root module to be replaced, can a regex string (use the first matching module) or full match format' + }) + + root_modules_hook: str = field( + default='input', metadata={'help': 'The hook type of root modules, can be "input" or "output"'}) + + stem_modules: Optional[Union[List[str], str]] = field( + default=None, + metadata={'help': 'The stem modules to be replaced, can a regex string or name list of full match format'}) + + stem_modules_hook: str = field( + default='output', metadata={'help': 'The hook type of stem modules, can be "input" or "output"'}) + + target_modules: str = field( + default=None, + metadata={ + 'help': + 'The target module to be replaced, can a regex string (use the first matching module) or full match format' + }) + + target_modules_hook: str = field( + default='input', metadata={'help': 'The hook type of target modules, can be "input" or "output"'}) + + target_hidden_pos: Union[int, str] = field( + default=None, metadata={'help': 'The position of the hidden state for target modules output'}) + + tuner_cfg: Optional[Union[List[Dict], Dict, str]] = field( + default=None, metadata={'help': 'The configuration of the tuning module, can a string or customized config'}) + + use_upsample: bool = field(default=False, metadata={'help': 'Whether to use auxiliary upsample module'}) + + upsample_out_channels: List[int] = field( + default=None, metadata={'help': 'The number of output channels when "use_upsample" is set to "True"'}) + + zero_init_last: bool = field(default=False, metadata={'help': 'Zero init last weight'}) + + use_bypass: bool = 
field(default=True, metadata={'help': 'Whether to use bypass'}) + + def __post_init__(self): + from .mapping import SwiftTuners + self.swift_type = SwiftTuners.RESTUNING + self.target_hidden_pos = 0 if self.target_hidden_pos is None else self.target_hidden_pos + + +class ResTuning(SwiftAdapter): + + @staticmethod + def prepare_model(model: nn.Module, config: ResTuningConfig, adapter_name: str) -> SwiftOutput: + """Prepare a model with `ResTuningConfig`""" + + def _forward_seq(self, input, *args, **kwargs): + for idx, module in enumerate(self): + if idx >= len(self.origin_module_keys): + continue + input = module(input) + return input + + def _forward_target(self, *args, **kwargs): + if self.target_modules_hook == 'input': + if isinstance(self.target_hidden_pos, int): + args = list(args) + _arg = args[self.target_hidden_pos] + else: + _arg = kwargs[self.target_hidden_pos] + args_main = _forward_restuning(self, _arg) + if isinstance(self.target_hidden_pos, int): + args[self.target_hidden_pos] = args_main + else: + kwargs[self.target_hidden_pos] = args_main + args_main = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs) + else: + _args_main = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs) + _arg = _args_main[self.target_hidden_pos] if isinstance(_args_main, (tuple, list, dict)) else _args_main + args_main = _forward_restuning(self, _arg) + if type(_args_main) != type(args_main): + _args_main[self.target_hidden_pos] = args_main + args_main = _args_main + return args_main + + def _forward_restuning(self, origin_arg): + probe_results = [] + root_module_ins = self.root_module_ins_list[0] + stem_module_ins_list = self.stem_module_ins_list + top_module = model.get_submodule('') + if root_module_ins: + if root_module_ins.root_modules_hook == 'input': + probe_results.append(root_module_ins.probe_input_data) + else: + probe_results.append(root_module_ins.probe_output_data) + for i, st_mod in enumerate(stem_module_ins_list): + if i == 0 and root_module_ins is None: + probe_results.append(st_mod.probe_input_data) + if st_mod.stem_modules_hook == 'input': + probe_results.append(st_mod.probe_input_data) + else: + probe_results.append(st_mod.probe_output_data) + args_main = getattr(top_module, f'restuning_{adapter_name}')(probe_results, origin_arg) + return args_main + + # 1. Matching the root module + module_keys = [key for key, _ in model.named_modules()] + root_module_ins_list = [] + if config.root_modules: + for module_key in module_keys: + if re.fullmatch(config.root_modules, module_key): + root_module = model.get_submodule(module_key) + logger.info(f'Matching root module [{module_key}] of type {type(root_module)}') + if isinstance(root_module, (nn.ModuleList, nn.ModuleDict)): + logger.warning( + f'Type of {type(root_module)} may not be supported because of its customized forward') + if config.root_modules_hook == 'input': + root_module.register_forward_pre_hook(probe_input_pre_hook) + else: + root_module.register_forward_hook(probe_output_hook) + root_module.root_modules_hook = config.root_modules_hook + root_module_ins_list.append(root_module) + break + if len(root_module_ins_list) == 0: + logger.error('Cannot match root modules') + + # 2. 
Matching the stem module + stem_module_ins_list = [] + stem_module_ins_index = [] + for module_key in module_keys: + if (isinstance(config.stem_modules, str) and re.fullmatch(config.stem_modules, module_key)) or \ + (isinstance(config.stem_modules, list) and module_key in config.stem_modules): + stem_module = model.get_submodule(module_key) + if isinstance(config.stem_modules, list): + stem_module_ins_index.append(config.stem_modules.index(module_key)) + logger.info(f'Matching stem module [{module_key}] of type {type(stem_module)}') + if isinstance(stem_module, (nn.ModuleList, nn.ModuleDict)): + logger.warning( + f'Type of {type(stem_module)} may not be supported because of its customized forward') + if len(root_module_ins_list) == 0 and len(stem_module_ins_list) == 0: + stem_module.register_forward_pre_hook(probe_input_pre_hook) + if config.stem_modules_hook == 'input': + stem_module.register_forward_pre_hook(probe_input_pre_hook) + else: + stem_module.register_forward_hook(probe_output_hook) + stem_module.stem_modules_hook = config.stem_modules_hook + stem_module_ins_list.append(stem_module) + if isinstance(config.stem_modules, list): + stem_module_ins_list = [ + stem_module_ins_list[stem_module_ins_index.index(i)] for i in range(len(stem_module_ins_index)) + ] + depth = len(stem_module_ins_list) + if len(stem_module_ins_list) == 0: + raise Exception('Cannot match source modules') + + # 3. Init restuning module + if len(stem_module_ins_list) != 0: + top_module = model.get_submodule('') + restuning_module = ResTuningBypassModule(config.dims, depth, adapter_name, config.use_upsample, + config.upsample_out_channels, config.zero_init_last, + config.tuner_cfg) + setattr(top_module, f'restuning_{adapter_name}', restuning_module) + + # 4. Matching the target module + target_module_ins = None + for module_key in module_keys: + if re.fullmatch(config.target_modules, module_key): + tgt_module = model.get_submodule(module_key) + logger.info(f'Matching target module [{module_key}] of type {type(tgt_module)}') + if isinstance(tgt_module, (nn.ModuleList, nn.ModuleDict)): + raise Exception( + f'Type of {type(tgt_module)} may not be supported because of its customized forward') + + tgt_module.target_modules_hook = config.target_modules_hook + tgt_module.target_hidden_pos = config.target_hidden_pos + tgt_module.root_module_ins_list = root_module_ins_list + tgt_module.stem_module_ins_list = stem_module_ins_list + target_module_ins = tgt_module + + if isinstance(tgt_module, nn.Sequential) and not hasattr(tgt_module, 'origin_module_keys'): + tgt_module.origin_module_keys = copy.deepcopy(list(tgt_module._modules.keys())) + + setattr(tgt_module, f'forward_origin_{adapter_name}', types.MethodType(_forward_seq, tgt_module)) + else: + setattr(tgt_module, f'forward_origin_{adapter_name}', tgt_module.forward) + tgt_module.forward = types.MethodType(_forward_target, tgt_module) + if target_module_ins is None: + raise Exception('Cannot match target modules') + + def state_dict_callback(state_dict, adapter_name, **kwargs): + return {key: value for key, value in state_dict.items() if f'restuning_{adapter_name}' in key} + + def mark_trainable_callback(model): + return + + return SwiftOutput( + config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback) + + @staticmethod + def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None): + modules = find_sub_module(module, f'restuning_{adapter_name}') + for _module in modules: + _module: 
ActivationMixin + _module: nn.Module + _module.set_activation(adapter_name, activate) + SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload) + + +class ResTuningBypassModule(nn.Module, ActivationMixin): + """The implementation of ResTuningBypass method. + """ + + def __init__( + self, + dims, + depth, + adapter_name, + use_upsample=False, + upsample_out_channels=None, + zero_init_last=False, + tuner_cfg=None, + ): + super(ResTuningBypassModule, self).__init__() + super(nn.Module, self).__init__('') + self.adapter_name = adapter_name + + self.bypass_blocks = nn.Sequential(*[ + ResTunerBypassBlock( + dim=dims[i] if isinstance(dims, list) else dims, + layer_num=i, + depth=depth, + use_upsample=use_upsample, + upsample_out_channels=upsample_out_channels[i] if isinstance(upsample_out_channels, list + ) else upsample_out_channels, + zero_init_last=zero_init_last, + tuner_cfg=tuner_cfg[i] if isinstance(tuner_cfg, list) else tuner_cfg) for i in range(depth) + ]) + self.mark_all_sub_modules_as_plugin() + + def forward(self, x_list, origin_arg, **kwargs): + if not self.is_activated(self.adapter_name): + return origin_arg + x_bypass = detach_tensors(x_list.pop(0)) + x_bypass = x_bypass[0] if isinstance(x_bypass, (list, tuple)) else x_bypass + x_list = detach_tensors(x_list) + x_list = [_x[0] if isinstance(_x, (list, tuple)) else _x for _x in x_list] + for i, (bp_blk, x_stem) in enumerate(zip(self.bypass_blocks, x_list)): + target_size = x_list[i + 1].shape[2:] if i < len(x_list) - 1 else None + x_bypass = bp_blk(x_stem, x_bypass, target_size, **kwargs) + return x_bypass + + +class ResTunerBypassBlock(nn.Module): + + def __init__(self, dim, layer_num=-1, depth=-1, use_upsample=False, zero_init_last=False, tuner_cfg=None, **kwargs): + super().__init__() + self.layer_num = layer_num + self.depth = depth + + if isinstance(tuner_cfg, str): + lateral_cfg = tuner_cfg + vertical_cfg = tuner_cfg + aux_cfg = 'upsample' if use_upsample and layer_num != depth - 1 else None + elif isinstance(tuner_cfg, dict): + lateral_cfg = tuner_cfg['lateral_cfg'] if 'lateral_cfg' in tuner_cfg else None + vertical_cfg = tuner_cfg['vertical_cfg'] if 'vertical_cfg' in tuner_cfg else None + aux_cfg = tuner_cfg['aux_cfg'] if 'aux_cfg' in tuner_cfg else None + + self.lateral_tuner = ResTuner(dim, layer_num, depth, zero_init_last, 'lateral', lateral_cfg, **kwargs) + self.vertical_tuner = ResTuner(dim, layer_num, depth, zero_init_last, 'vertical', vertical_cfg, **kwargs) + if aux_cfg and len(aux_cfg) != 0: + self.aux_tuner = ResTuner(dim, layer_num, depth, zero_init_last, 'aux', aux_cfg, **kwargs) + + def forward(self, x_stem, x_bypass, target_size=None, **kwargs): + x_lateral = self.lateral_tuner(x_stem) + x_vertical = self.vertical_tuner(x_bypass) + + x_bypass_out = x_lateral + x_vertical + if hasattr(self, 'aux_tuner'): + x_bypass_out = self.aux_tuner(x_bypass_out, target_size) + return x_bypass_out diff --git a/swift/tuners/restuning_components.py b/swift/tuners/restuning_components.py new file mode 100644 index 0000000000000000000000000000000000000000..ed4f53df7f789316edbc8858eebb6b1319d93214 --- /dev/null +++ b/swift/tuners/restuning_components.py @@ -0,0 +1,351 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
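The control flow in `ResTuningBypassModule.forward` above is easier to see with toy shapes: the first probed feature seeds the bypass state, and every block mixes a lateral view of its stem feature with a vertical view of the running bypass state. A self-contained sketch with hypothetical dimensions and plain `nn.Linear` layers standing in for the `ResTuner` ops:

```python
import torch
from torch import nn

class ToyBypassBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.lateral = nn.Linear(dim, dim)   # consumes the probed stem feature
        self.vertical = nn.Linear(dim, dim)  # consumes the previous bypass state

    def forward(self, x_stem, x_bypass):
        return self.lateral(x_stem) + self.vertical(x_bypass)

dim, depth = 16, 3
blocks = nn.ModuleList(ToyBypassBlock(dim) for _ in range(depth))

# Features probed from `depth` backbone stages plus the backbone input.
stem_feats = [torch.randn(2, 10, dim) for _ in range(depth + 1)]
x_bypass = stem_feats[0]
for block, x_stem in zip(blocks, stem_feats[1:]):
    x_bypass = block(x_stem, x_bypass)
print(x_bypass.shape)  # torch.Size([2, 10, 16])
```

The real blocks can additionally apply an upsample `aux_tuner` so spatial sizes match between stages when `use_upsample` is set.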
+import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from swift.utils.logger import get_logger + +logger = get_logger() + + +class ResTuner(nn.Module): + + def __init__(self, dim=None, layer_num=-1, depth=-1, zero_init_last=False, stage='', tuner_cfg={}, **kwargs): + super().__init__() + self.dim = dim + self.layer_num = layer_num + self.depth = depth + self.stage = stage + self.tuner_cfg = tuner_cfg + + if (isinstance(tuner_cfg, str) and tuner_cfg == 'res_adapter') or \ + (isinstance(tuner_cfg, dict) and 'res_adapter' in tuner_cfg): + tuner_cfg = tuner_cfg['res_adapter'] if isinstance(tuner_cfg, dict) else tuner_cfg + self.tuner = ResAdapter( + dim=dim, + layer_num=layer_num, + depth=depth, + zero_init_last=zero_init_last, + stage=stage, + tuner_cfg=tuner_cfg, + **kwargs) + elif (isinstance(tuner_cfg, str) and tuner_cfg == 'res_group_adapter') or \ + (isinstance(tuner_cfg, dict) and 'res_group_adapter' in tuner_cfg): + tuner_cfg = tuner_cfg['res_group_adapter'] if isinstance(tuner_cfg, dict) else tuner_cfg + self.tuner = ResGroupAdapter( + dim=dim, + layer_num=layer_num, + depth=depth, + zero_init_last=zero_init_last, + stage=stage, + tuner_cfg=tuner_cfg, + **kwargs) + elif (isinstance(tuner_cfg, str) and tuner_cfg == 'upsample') or \ + (isinstance(tuner_cfg, dict) and 'upsample' in tuner_cfg): + tuner_cfg = tuner_cfg['upsample'] if isinstance(tuner_cfg, dict) else tuner_cfg + if 'upsample_out_channels' in kwargs: + out_channels = kwargs['upsample_out_channels'] + use_conv = True if out_channels else False + else: + out_channels = dim + use_conv = False + self.tuner = Upsample( + channels=dim, use_conv=use_conv, out_channels=out_channels, tuner_cfg=tuner_cfg, **kwargs) + else: + self.tuner = Identity() + + def forward(self, x, *args, **kwargs): + if self.tuner_cfg == 'zero' or 'zero' in self.tuner_cfg: + x_out = 0.0 + else: + x_out = self.tuner(x, *args, **kwargs) + return x_out + + +class ResAdapter(nn.Module): + + def __init__(self, + dim, + layer_num=-1, + depth=-1, + zero_init_last=False, + stage='', + tuner_cfg=None, + act_layer=nn.GELU, + **kwargs): + super(ResAdapter, self).__init__() + self.dim = dim + self.layer_num = layer_num + self.depth = depth + + self.adapter_length = tuner_cfg['adapter_length'] if 'adapter_length' in tuner_cfg else 32 + self.adapter_type = tuner_cfg['adapter_type'] if 'adapter_type' in tuner_cfg else None + self.adapter_weight = tuner_cfg['adapter_weight'] if 'adapter_weight' in tuner_cfg else None + + self.adapter_length = self.adapter_length[self.layer_num] if isinstance(self.adapter_length, + list) else self.adapter_length + assert isinstance(self.adapter_length, int) or (isinstance(self.adapter_length, tuple) + and len(self.adapter_length) == 3) + if isinstance(self.adapter_length, int): + self.ln1 = nn.Linear(dim, self.adapter_length) + else: + self.ln1 = nn.Linear(self.adapter_length[0], self.adapter_length[1]) + self.activate = act_layer() + if isinstance(self.adapter_length, int): + self.ln2 = nn.Linear(self.adapter_length, dim) + else: + self.ln2 = nn.Linear(self.adapter_length[1], self.adapter_length[2]) + dim = self.adapter_length[2] + + self._xavier_init_weights(self.ln1) + if zero_init_last and layer_num == depth - 1: + self._zero_init_weights(self.ln2) + else: + self._xavier_init_weights(self.ln2) + + self.scaling = init_weight_type(dim, self.adapter_weight) + self._prepared = False + + def _zero_init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.zeros_(m.weight) + 
nn.init.zeros_(m.bias) + + def _kaiming_init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5)) + nn.init.normal_(m.bias) + + def _xavier_init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.normal_(m.bias, std=1e-6) + + def forward(self, x): + if not self._prepared: + self.ln1.to(x.device) + self.activate.to(x.device) + self.ln2.to(x.device) + self._prepared = True + + x_dtype = x.dtype + x = x.to(self.ln1.weight.dtype) + x_shortcut = x + if len(x_shortcut.size()) == 4: + B, C, N1, N2 = x.size() + x = x.view(x_shortcut.size()[0], x_shortcut.size()[1], -1).permute(0, 2, 1) + + x_adapter = self.ln2(self.activate(self.ln1(x))) + + if self.adapter_weight: + x_adapter = apply_data_weight(x_adapter, self.scaling, self.adapter_weight) + + if len(x_shortcut.size()) == 4: + x_adapter = x_adapter.permute(0, 2, 1).view(x_shortcut.size()[0], + x_adapter.size()[-1], + x_shortcut.size()[2], + x_shortcut.size()[3]) + x_out = x_shortcut + x_adapter + return x_out.to(x_dtype) + + +class ResGroupAdapter(nn.Module): + + def __init__(self, + dim, + layer_num=-1, + depth=-1, + zero_init_last=False, + stage='', + tuner_cfg=None, + act_layer=nn.GELU, + **kwargs): + super(ResGroupAdapter, self).__init__() + self.dim = dim + self.layer_num = layer_num + self.depth = depth + + self.adapter_type = tuner_cfg['adapter_type'] if 'adapter_type' in tuner_cfg else None + self.adapter_weight = tuner_cfg['adapter_weight'] if 'adapter_weight' in tuner_cfg else None + + self.adapter_dim = tuner_cfg['dim'] if 'dim' in tuner_cfg else dim + self.adapter_head = tuner_cfg['head'] if 'head' in tuner_cfg else 4 + self.adapter_scale_factor = tuner_cfg['scale_factor'] if 'scale_factor' in tuner_cfg else 2 + + assert self.adapter_dim % self.adapter_head == 0, 'adapter dim should be divisible by adapter head' + self.dim_mlp = self.adapter_dim // self.adapter_head + + self.ln1 = nn.Linear(self.dim_mlp, self.dim_mlp * self.adapter_scale_factor) + self.ln2 = nn.Linear(self.dim_mlp * self.adapter_scale_factor, self.dim_mlp) + self.activate = act_layer() + + self._kaiming_init_weights(self.ln1) + if zero_init_last and layer_num == depth - 1: + self._zero_init_weights(self.ln2) + else: + self._kaiming_init_weights(self.ln2) + self.scaling = init_weight_type(dim, self.adapter_weight) + self._prepared = False + + def _zero_init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.zeros_(m.weight) + nn.init.zeros_(m.bias) + + def _kaiming_init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5)) + nn.init.normal_(m.bias) + + def _xavier_init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.normal_(m.bias, std=1e-6) + + def forward(self, x): + if not self._prepared: + self.ln1.to(x.device) + self.activate.to(x.device) + self.ln2.to(x.device) + self._prepared = True + + x_dtype = x.dtype + x = x.to(self.ln1.weight.dtype) + x_shortcut = x + + batch, inner_dim, height, width = x.shape + + x_adapter = x.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + + x_adapter = rearrange(x_adapter, 'b n (c h) -> (b h) n c', h=self.adapter_head) + x_adapter = self.ln2(self.activate(self.ln1(x_adapter))) + x_adapter = rearrange(x_adapter, '(b h) n c -> b n (c h)', h=self.adapter_head) + + if self.adapter_weight: + x_adapter = apply_data_weight(x_adapter, self.scaling, self.adapter_weight) + + x_adapter = x_adapter.reshape(batch, height, width, 
-1).permute(0, 3, 1, 2).contiguous() + x_out = x_shortcut + x_adapter + + return x_out.to(x_dtype) + + +class Identity(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, inputs, *args, **kwargs): + return inputs + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, **kwargs): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + if use_conv: + self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=padding) + self.init_weights() + + def init_weights(self): + + def _init_weights(m): + if isinstance(m, nn.Conv2d): + nn.init.zeros_(m.weight) + nn.init.zeros_(m.bias) + + self.apply(_init_weights) + + def forward(self, x, target_size=None, *args, **kwargs): + assert x.shape[1] == self.channels + if target_size is None: + x = F.interpolate(x.float(), scale_factor=2, mode='nearest').type_as(x) + else: + x = F.interpolate(x.float(), target_size, mode='nearest').type_as(x) + if self.use_conv: + x = self.conv(x) + return x + + +def init_weight_type(dim, weight_type): + if weight_type is None: + scaling = None + elif weight_type == 'gate': + scaling = nn.Linear(dim, 1) + elif weight_type == 'scale': + scaling = nn.Parameter(torch.Tensor(1)) + scaling.data.fill_(1) + elif weight_type == 'scale_kv': + scaling_k = nn.Parameter(torch.Tensor(1)) + scaling_k.data.fill_(1) + scaling_v = nn.Parameter(torch.Tensor(1)) + scaling_v.data.fill_(1) + scaling = (scaling_k, scaling_v) + elif weight_type == 'scale_channel': + scaling = nn.Parameter(torch.Tensor(dim)) + scaling.data.fill_(1) + elif weight_type == 'scale_kv_channel': + scaling_k = nn.Parameter(torch.Tensor(dim)) + scaling_k.data.fill_(1) + scaling_v = nn.Parameter(torch.Tensor(dim)) + scaling_v.data.fill_(1) + scaling = (scaling_k, scaling_v) + elif weight_type and weight_type.startswith('scalar'): + scaling = float(weight_type.split('_')[-1]) + else: + scaling = None + return scaling + + +def apply_data_weight(data, scaling, weight_type): + if weight_type in ['gate']: + scaling = torch.mean(torch.sigmoid(scaling(data)), dim=1).view(-1, 1, 1) + elif weight_type in ['scale', 'scale_channel'] or weight_type.startswith('scalar'): + scaling = scaling + else: + scaling = None + if scaling is not None: + data = data * scaling + return data + + +def detach_tensors(feats): + if type(feats) in [list, tuple]: + feats = [detach_tensors(feat) if feat is not None else None for feat in feats] + elif isinstance(feats, dict): + feats = {key: detach_tensors(val) for key, val in feats.items()} + elif isinstance(feats, torch.Tensor): + feats = feats.detach() + else: + feats = feats.detach() + return feats + + +def probe_tensors(module, feats, name): + feats = detach_tensors(feats) + setattr(module, name, feats) + + +def probe_input_pre_hook(self, args): + input = args[0] + probe_tensors(self, input, 'probe_input_data') + return args + + +def probe_output_hook(self, args, result): + output = result + probe_tensors(self, output, 'probe_output_data') + return output diff --git a/swift/tuners/scetuning/__init__.py b/swift/tuners/scetuning/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..73607de8c4a1deddca575468a278fa75d32e979e --- /dev/null +++ b/swift/tuners/scetuning/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .scetuning import SCETuning, SCETuningConfig diff --git a/swift/tuners/scetuning/__pycache__/__init__.cpython-310.pyc b/swift/tuners/scetuning/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e5dc98d94efce8a4e3544841ab9de516a71ec66 Binary files /dev/null and b/swift/tuners/scetuning/__pycache__/__init__.cpython-310.pyc differ diff --git a/swift/tuners/scetuning/__pycache__/scetuning.cpython-310.pyc b/swift/tuners/scetuning/__pycache__/scetuning.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf62e3719597b6eb8631f90ec6db4a738bcfb52a Binary files /dev/null and b/swift/tuners/scetuning/__pycache__/scetuning.cpython-310.pyc differ diff --git a/swift/tuners/scetuning/__pycache__/scetuning_components.cpython-310.pyc b/swift/tuners/scetuning/__pycache__/scetuning_components.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b94f9a73b5e1f73f663562fa90c4f344054a6c54 Binary files /dev/null and b/swift/tuners/scetuning/__pycache__/scetuning_components.cpython-310.pyc differ diff --git a/swift/tuners/scetuning/scetuning.py b/swift/tuners/scetuning/scetuning.py new file mode 100644 index 0000000000000000000000000000000000000000..c105cd1baef206f64d0f9ce82333eab1e94f5dfd --- /dev/null +++ b/swift/tuners/scetuning/scetuning.py @@ -0,0 +1,235 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import re +import types +from dataclasses import dataclass, field +from typing import List, Optional, Union + +import torch +from torch import nn + +from swift.tuners.utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput +from swift.utils import get_logger +from swift.utils.torch_utils import find_sub_module +from .scetuning_components import probe_output_hook + +logger = get_logger() + + +@dataclass +class SCETuningConfig(SwiftConfig): + """ + The configuration class for the SCEdit module. + + 'SCEdit: Efficient and Controllable Image Diffusion Generation via Skip Connection Editing' by Jiang et al.(2023) + See https://arxiv.org/abs/2312.11392 + + Args: + dims(`Union[List[int], int]`): The dimensions of the hidden states + target_modules(`Union[List[str], str]`): The target module to be replaced, can a regex string + hint_modules(`Union[List[str], str]`): The hint module to be replaced, can a regex string + tuner_mode(`str`): Location of tuner operation. + tuner_op(`str`): Tuner operation. + down_ratio(`float`): The dim down ratio of tuner hidden state. + """ + + dims: Optional[Union[List[int], int]] = field( + default=None, metadata={'help': 'The dimensions of the hidden states'}) + + target_modules: Optional[Union[List[str], str]] = field( + default=None, + metadata={'help': 'The target module to be replaced, can be a regex string or name list of full match format'}) + + hint_modules: Optional[Union[List[str], str]] = field( + default=None, + metadata={'help': 'The hint modules to be replaced, can be a regex string or name list of full match format'}) + + tuner_mode: str = field( + default='decoder', + metadata={'help': 'Location of tuner operation. 
The tuner mode choices: encoder, decoder, and identity'}) + + tuner_op: str = field(default='SCEAdapter', metadata={'help': 'The tuner ops choices: SCEAdapter'}) + + down_ratio: float = field(default=1.0, metadata={'help': 'The dim down ratio of tuner hidden state'}) + + def __post_init__(self): + from swift.tuners.mapping import SwiftTuners + self.swift_type = SwiftTuners.SCETUNING + + +class SCETuning(SwiftAdapter): + + @staticmethod + def prepare_model(model: nn.Module, config: SCETuningConfig, adapter_name: str) -> SwiftOutput: + """Prepare a model with `SCETuningConfig`""" + module_keys = [key for key, _ in model.named_modules()] + # 1. Matching the hint module + hint_module_ins_list = [] + if config.hint_modules: + if isinstance(config.hint_modules, list): + for module_key in config.hint_modules: + assert module_key in module_keys + h_module = model.get_submodule(module_key) + logger.info(f'Matching hint module [{module_key}] of type {type(h_module)}') + if isinstance(h_module, (nn.ModuleList, nn.ModuleDict)): + logger.warning( + f'Type of {type(h_module)} may not be supported because of its customized forward') + h_module.register_forward_hook(probe_output_hook, with_kwargs=True) + hint_module_ins_list.append(h_module) + else: + for module_key in module_keys: + if re.fullmatch(config.hint_modules, module_key): + h_module = model.get_submodule(module_key) + logger.info(f'Matching hint module [{module_key}] of type {type(h_module)}') + if isinstance(h_module, (nn.ModuleList, nn.ModuleDict)): + logger.warning( + f'Type of {type(h_module)} may not be supported because of its customized forward') + h_module.register_forward_hook(probe_output_hook, with_kwargs=True) + hint_module_ins_list.append(h_module) + if len(hint_module_ins_list) == 0: + logger.error('Cannot match hint modules') + + def _get_module(module): + if isinstance(module, nn.ModuleList): + module = module[-1] + return _get_module(module) + return module + + # 2. 
Matching the target module + target_module_ins_list = [] + assert config.target_modules is not None + if isinstance(config.target_modules, list): + for module_key in config.target_modules: + assert module_key in module_keys + t_module = model.get_submodule(module_key) + logger.info(f'Matching target module [{module_key}] of type {type(t_module)}') + target_module_ins_list.append(_get_module(t_module)) + else: + for module_key in module_keys: + if re.fullmatch(config.target_modules, module_key): + t_module = model.get_submodule(module_key) + logger.info(f'Matching target module [{module_key}] of type {type(t_module)}') + target_module_ins_list.append(_get_module(t_module)) + if len(target_module_ins_list) == 0: + logger.error('Cannot match target modules') + if len(hint_module_ins_list) > 0 and not len(hint_module_ins_list) == len(target_module_ins_list): + logger.info("Target modules' length should be equal with hint modules.") + assert len(hint_module_ins_list) == len(target_module_ins_list) + if isinstance(config.dims, int): + dims = [config.dims for _ in target_module_ins_list] + else: + assert len(config.dims) == len(target_module_ins_list) + dims = config.dims + + # refactor forward function + def _forward_encoder_mode(self, *args, **kwargs): + args = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs) + args_type = type(args) + if args_type is tuple: + args = args[0] + if hasattr(self, 'hint'): + hint_out = self.hint.probe_output_data + args_main = getattr(self, f'scetuner_{adapter_name}')(args, hint_out) + else: + args_main = getattr(self, f'scetuner_{adapter_name}')(args) + if args_type is tuple: + args_main = (args_main, ) + return args_main + + def _forward_decoder_mode(self, *args, **kwargs): + args_type = type(args) + if args_type is tuple: + args_sub_tuner = args[0] + args_sub_extra = args[1:] + tuner_module = getattr(self, f'scetuner_{adapter_name}') + args_hidden, args_res = torch.split(args_sub_tuner, args_sub_tuner.shape[1] - tuner_module.dim, 1) + if hasattr(self, 'hint'): + hint_out = self.hint.probe_output_data + args_res_new = tuner_module(args_res, hint_out) + else: + args_res_new = tuner_module(args_res) + args_sub_tuner_new = torch.cat([args_hidden, args_res_new], dim=1) + if args_type is tuple: + args_main = (args_sub_tuner_new, *args_sub_extra) + + args_main = getattr(self, f'forward_origin_{adapter_name}')(*args_main, **kwargs) + return args_main + + # 3. 
inject the tuners + for tuner_id, t_module in enumerate(target_module_ins_list): + setattr(t_module, f'forward_origin_{adapter_name}', getattr(t_module, 'forward')) + if config.tuner_mode in ('encoder', 'identity'): + _forward = _forward_encoder_mode + elif config.tuner_mode == 'decoder': + _forward = _forward_decoder_mode + else: + raise Exception(f'Error tuner_mode: {config.tuner_mode}') + setattr(t_module, 'forward', types.MethodType(_forward, t_module)) + tuner_op = SCETunerModule( + name=config.tuner_op, + adapter_name=adapter_name, + module_key=str(tuner_id), + dim=dims[tuner_id], + tuner_length=int(dims[tuner_id] * config.down_ratio)) + setattr(t_module, f'scetuner_{adapter_name}', tuner_op) + if len(hint_module_ins_list) > 0: + setattr(t_module, 'hint', hint_module_ins_list[tuner_id]) + + def state_dict_callback(state_dict, adapter_name, **kwargs): + state_dict_new = {key: value for key, value in state_dict.items() if f'scetuner_{adapter_name}' in key} + return state_dict_new + + def mark_trainable_callback(model): + return + + return SwiftOutput( + config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback) + + @staticmethod + def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None): + modules = find_sub_module(module, f'scetuner_{adapter_name}') + for _module in modules: + _module: ActivationMixin + _module: nn.Module + _module.set_activation(adapter_name, activate) + SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload) + + +class SCETunerModule(nn.Module, ActivationMixin): + + def __init__(self, + name, + adapter_name, + module_key, + dim, + tuner_length, + tuner_type=None, + tuner_weight=None, + act_layer=nn.GELU, + zero_init_last=True, + use_bias=True): + super(SCETunerModule, self).__init__() + super(nn.Module, self).__init__(module_key) + self.name = name + self.adapter_name = adapter_name + self.dim = dim + if name == 'SCEAdapter': + from .scetuning_components import SCEAdapter + self.tuner_op = SCEAdapter( + dim=dim, + adapter_length=tuner_length, + adapter_type=tuner_type, + adapter_weight=tuner_weight, + act_layer=act_layer) + else: + raise Exception(f'Error tuner op {name}') + self.mark_all_sub_modules_as_plugin() + + def forward(self, x, x_shortcut=None, use_shortcut=True, **kwargs): + if not self.is_activated(self.adapter_name): + return x + if self.name == 'SCEAdapter': + self.tuner_op.to(x.device) + out = self.tuner_op(x) + else: + raise Exception(f'Error tuner op {self.name}') + return out diff --git a/swift/tuners/scetuning/scetuning_components.py b/swift/tuners/scetuning/scetuning_components.py new file mode 100644 index 0000000000000000000000000000000000000000..7b7b981d15bc394710f504ffb630fd08cb061d75 --- /dev/null +++ b/swift/tuners/scetuning/scetuning_components.py @@ -0,0 +1,127 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
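+# Illustrative usage sketch for the SCEAdapter defined below (sizes are hypothetical):
+#   adapter = SCEAdapter(dim=1280, adapter_length=640)
+#   y = adapter(x)  # ln1 -> activation -> ln2 bottleneck, optional scaling, plus shortcut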
+import math + +import torch +import torch.nn as nn + +from swift.utils.logger import get_logger + +logger = get_logger() + + +def detach_tensors(feats): + if type(feats) in [list, tuple]: + feats = [detach_tensors(feat) if feat is not None else None for feat in feats] + elif isinstance(feats, dict): + feats = {key: detach_tensors(val) for key, val in feats.items()} + elif isinstance(feats, torch.Tensor): + feats = feats.detach() + else: + feats = feats.detach() + return feats + + +def probe_tensors(module, feats, name): + feats = detach_tensors(feats) + setattr(module, name, feats) + + +def probe_input_pre_hook(self, args): + input = args[0] + probe_tensors(self, input, 'probe_input_data') + return args + + +def probe_output_hook(self, args, result): + output = result + probe_tensors(self, output, 'probe_output_data') + return output + + +def choose_weight_type(weight_type, dim): + if weight_type == 'gate': + scaling = nn.Linear(dim, 1) + elif weight_type == 'scale': + scaling = nn.Parameter(torch.Tensor(1)) + scaling.data.fill_(1) + elif weight_type == 'scale_channel': + scaling = nn.Parameter(torch.Tensor(dim)) + scaling.data.fill_(1) + elif weight_type and weight_type.startswith('scalar'): + scaling = float(weight_type.split('_')[-1]) + else: + scaling = None + return scaling + + +def get_weight_value(weight_type, scaling, x): + if weight_type in ['gate']: + scaling = torch.mean(torch.sigmoid(scaling(x)), dim=1).view(-1, 1, 1) + elif weight_type in ['scale', 'scale_channel'] or weight_type.startswith('scalar'): + scaling = scaling + else: + scaling = None + return scaling + + +class SCEAdapter(nn.Module): + + def __init__(self, + dim, + adapter_length, + adapter_type=None, + adapter_weight=None, + act_layer=nn.GELU, + zero_init_last=True, + use_bias=True): + super(SCEAdapter, self).__init__() + self.dim = dim + self.adapter_length = adapter_length + self.adapter_type = adapter_type + self.adapter_weight = adapter_weight + self.zero_init_last = zero_init_last + self.ln1 = nn.Linear(dim, adapter_length, bias=use_bias) + self.activate = act_layer() + self.ln2 = nn.Linear(adapter_length, dim, bias=use_bias) + self.init_weights() + self.init_scaling() + + def _zero_init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.zeros_(m.weight) + nn.init.zeros_(m.bias) + + def _kaiming_init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5)) + + def init_weights(self): + self._kaiming_init_weights(self.ln1) + if self.zero_init_last: + self._zero_init_weights(self.ln2) + else: + self._kaiming_init_weights(self.ln2) + + def init_scaling(self): + if self.adapter_weight: + self.scaling = choose_weight_type(self.adapter_weight, self.dim) + else: + self.scaling = None + + def forward(self, x, x_shortcut=None, use_shortcut=True, **kwargs): + if x_shortcut is None: + x_shortcut = x + x_shape = x.shape + if len(x_shape) == 4: + b, d, h, w = x_shape + x = x.permute(0, 2, 3, 1).reshape(b, h * w, d) + out = self.ln2(self.activate(self.ln1(x))) + if self.adapter_weight: + scaling = get_weight_value(self.adapter_weight, self.scaling, out) + out = out * scaling if scaling is not None else out + if len(x_shape) == 4: + b, d, h, w = x_shape + out = out.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous() + if use_shortcut: + out = x_shortcut + out + return out diff --git a/swift/tuners/side.py b/swift/tuners/side.py new file mode 100644 index 0000000000000000000000000000000000000000..a315bcd3a9527c38d96ac34a9da59cf04e01c91c --- /dev/null +++ 
b/swift/tuners/side.py @@ -0,0 +1,245 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import copy +import re +import types +from collections import OrderedDict +from dataclasses import dataclass, field +from functools import partial +from itertools import repeat +from typing import Union + +import torch +from torch import nn + +from swift.utils.logger import get_logger +from swift.utils.torch_utils import find_sub_module +from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput + +logger = get_logger() + + +@dataclass +class SideConfig(SwiftConfig): + """ + The configuration class for the side module. + + Side-Tuning only needs to train one side network and + weights the output of pre-trained model and side network. + 'Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks' + by Zhang et al.(2019) + See https://arxiv.org/abs/1912.13503 + + Args: + target_modules: The feedforward module to be replaced, in regex format + """ + + dim: int = field(default=None, metadata={'help': 'The dimension of the hidden states'}) + + target_modules: str = field( + default=None, metadata={'help': 'The target module to be replaced, in full match format'}) + + side_module_name: str = field(default='fcn4', metadata={'help': 'The name of the additive side networks'}) + + source_hidden_pos: Union[str, int] = field( + default=0, + metadata={ + 'help': 'The position of the hidden state input to the target module, can be int (args) or str (kwargs)' + }) + + target_hidden_pos: Union[str, int] = field( + default=0, + metadata={ + 'help': 'The position of the hidden state output from the target module, can be int (args) or str (kwargs)' + }) + + def __post_init__(self): + from .mapping import SwiftTuners + self.swift_type = SwiftTuners.SIDE + + +class Side(SwiftAdapter): + + @staticmethod + def prepare_model(model: nn.Module, config: SideConfig, adapter_name: str) -> SwiftOutput: + """Prepare a model with `SideConfig`""" + module_keys = [key for key, _ in model.named_modules()] + + for module_key in module_keys: + if re.fullmatch(config.target_modules, module_key): # noqa + tgt_module = model.get_submodule(module_key) + logger.info(f'Matching target module [{module_key}] of type {type(tgt_module)}') + if isinstance(tgt_module, (nn.ModuleList, nn.ModuleDict)): + raise Exception( + f'Type of {type(tgt_module)} may not be supported because of its customized forward') + + def _forward(self, *args, **kwargs): + args_main = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs) + + if isinstance(config.source_hidden_pos, int): + x = args[config.source_hidden_pos] + else: + x = kwargs[config.source_hidden_pos] + + x_main = args_main[config.target_hidden_pos] \ + if isinstance(args_main, (tuple, list, dict)) else args_main + out = getattr(self, f'side_{adapter_name}')(x, x_main) + if isinstance(args_main, (tuple, list, dict)): + args_main[config.target_hidden_pos] = out + else: + args_main = out + return args_main + + if isinstance(tgt_module, nn.Sequential) and not hasattr(tgt_module, 'tgt_module_keys'): + tgt_module.tgt_module_keys = copy.deepcopy(list(tgt_module._modules.keys())) + + def forward_seq(self, input, *args, **kwargs): + for idx, module in enumerate(self): + if idx >= len(tgt_module.tgt_module_keys): + continue + input = module(input) + return input + + setattr(tgt_module, f'forward_origin_{adapter_name}', types.MethodType(forward_seq, tgt_module)) + else: + setattr(tgt_module, f'forward_origin_{adapter_name}', tgt_module.forward) + tgt_module.forward = 
types.MethodType(_forward, tgt_module) + side_module = SideModule(config.dim, adapter_name, module_key, config.side_module_name) + setattr(tgt_module, f'side_{adapter_name}', side_module) + logger.info(f'Side modules(module_key): {module_key}.side_{adapter_name}') + + def state_dict_callback(state_dict, adapter_name, **kwargs): + return {key: value for key, value in state_dict.items() if f'side_{adapter_name}' in key} + + def mark_trainable_callback(model): + return + + return SwiftOutput( + config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback) + + @staticmethod + def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None): + modules = find_sub_module(module, f'side_{adapter_name}') + for _module in modules: + _module: ActivationMixin + _module: nn.Module + _module.set_activation(adapter_name, activate) + SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload) + + +class SideModule(nn.Module, ActivationMixin): + """The implementation of vision side-tuning method. + + Side-Tuning only needs to train one side network and + weights the output of pre-trained model and side network. + 'Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks' + by Zhang et al.(2019) + See https://arxiv.org/abs/1912.13503 + + Args: + side_module_name: The name of the additive side networks. + """ + + def __init__(self, dim, adapter_name, module_key, side_module_name='fcn4'): + super(SideModule, self).__init__() + super(nn.Module, self).__init__(module_key) + self.adapter_name = adapter_name + + side_module_name = side_module_name.lower() + if side_module_name == 'fcn4': + self.side_net = FCN4(out_dims=dim) + elif side_module_name == 'mlp': + self.side_net = Mlp(dim) + elif side_module_name == 'alexnet': + import torchvision + mm = torchvision.models.alexnet(pretrained=True) + self.side_net = nn.Sequential( + OrderedDict([('features', mm.features), ('avgpool', mm.avgpool), ('flatten', nn.Flatten()), + ('fc', nn.Linear(9216, dim, bias=False))])) + else: + raise ValueError(f'Unsupported side_module_name: {side_module_name}') + self.alpha = nn.Parameter(torch.tensor(0.0)) + self.mark_all_sub_modules_as_plugin() + + def forward(self, x, x_main): + if not self.is_activated(self.adapter_name): + return x_main + alpha_squashed = torch.sigmoid(self.alpha) + x_side = self.side_net(x) + x_out = alpha_squashed * x_main + (1 - alpha_squashed) * x_side + return x_out + + +class FCN4(nn.Module): + """The implementation of simple FCN4 network for side network. 
+ """ + + def __init__(self, out_dims=-1, **kwargs): + super(FCN4, self).__init__(**kwargs) + + self.conv1 = nn.Sequential( + nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False, dilation=1), nn.GroupNorm(2, 16), + nn.ReLU()) + self.conv2 = nn.Sequential( + nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=0, bias=False, dilation=1), nn.GroupNorm(2, 16), + nn.ReLU()) + self.conv3 = nn.Sequential( + nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=0, bias=False, dilation=1), nn.GroupNorm(2, 32), + nn.ReLU()) + self.conv4 = nn.Sequential( + nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=0, bias=False, dilation=1), nn.GroupNorm(2, 64), + nn.ReLU()) + self.pool = nn.AdaptiveAvgPool2d((1, 1)) + if out_dims > 0: + self.fc = nn.Linear(64, out_dims) + else: + self.fc = None + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + x = self.pool(x) + x = x.view(x.size(0), -1) + if self.fc is not None: + x = self.fc(x) + return x + + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer. + """ + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0., + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = tuple(repeat(bias, 2)) + drop_probs = tuple(repeat(drop, 2)) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x diff --git a/swift/tuners/utils.py b/swift/tuners/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7d6a578f7a5e83d89205be7f78c71e5569592dbf --- /dev/null +++ b/swift/tuners/utils.py @@ -0,0 +1,431 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2023-present the HuggingFace Inc. team. + +import hashlib +import os +import shutil +import tempfile +import threading +from dataclasses import asdict, dataclass, field +from types import FunctionType +from typing import Dict, Optional, Union + +import json +import numpy as np +import torch +from modelscope import snapshot_download +from modelscope.hub.utils.utils import get_cache_dir +from packaging import version +from peft.utils import CONFIG_NAME +from peft.utils import ModulesToSaveWrapper as _ModulesToSaveWrapper +from peft.utils import _get_submodules + +from swift.llm import MODEL_ARCH_MAPPING, ModelKeys +from swift.utils import gc_collect +from swift.utils.constants import BIN_EXTENSIONS +from swift.utils.logger import get_logger + +logger = get_logger() + + +@dataclass +class SwiftConfig: + + swift_type: str = field(default=None) + + model_key_mapping: Optional[Union[dict, ModelKeys]] = field(default=None) + + @property + def __dict__(self): + return asdict(self) + + def to_dict(self): + return self.__dict__ + + def save_pretrained(self, save_directory, **kwargs): + r""" + This method saves the configuration of your adapter model in a directory. 
+ + Args: + save_directory (`str`): + The directory where the configuration will be saved. + """ + if os.path.isfile(save_directory): + raise AssertionError(f'Provided path ({save_directory}) should be a directory, not a file') + + os.makedirs(save_directory, exist_ok=True) + + output_dict = self.__dict__ + output_dict.update(kwargs) + output_path = os.path.join(save_directory, CONFIG_NAME) + + # save it + with open(output_path, 'w', encoding='utf-8') as writer: + writer.write(json.dumps(output_dict, indent=2, sort_keys=True)) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + r""" + This method loads the configuration of your adapter model from a directory. + + Args: + pretrained_model_name_or_path (`str`): + The directory or the hub-id where the configuration is saved. + **kwargs: + Additional keyword arguments passed along to the child class initialization. + """ + if os.path.isfile(os.path.join(pretrained_model_name_or_path, CONFIG_NAME)): + config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) + else: + try: + model_dir = snapshot_download(pretrained_model_name_or_path, ignore_patterns=BIN_EXTENSIONS) + config_file = os.path.join(model_dir, CONFIG_NAME) + except Exception: + raise ValueError(f"Can't find config.json at '{pretrained_model_name_or_path}'") + + loaded_attributes = cls.from_json_file(config_file) + + from .mapping import SWIFT_MAPPING + assert loaded_attributes.get('swift_type', '') in SWIFT_MAPPING + config = SWIFT_MAPPING[loaded_attributes['swift_type']][0](**kwargs) + + for key, value in loaded_attributes.items(): + if hasattr(config, key): + setattr(config, key, value) + + return config + + @classmethod + def from_json_file(cls, path_json_file, **kwargs): + r""" + Loads a configuration file from a json file. + + Args: + path_json_file (`str`): + The path to the json file. + """ + with open(path_json_file, 'r', encoding='utf-8') as file: + json_object = json.load(file) + + return json_object + + +@dataclass +class SwiftOutput: + """The output class returned by all tuners. + + Args: + model (`torch.nn.Module`): The model wrapped + config (`SwiftConfig`): The swift config instance. + state_dict_callback (`FunctionType`): A callback returned by the tuner + which is used to get the tuner's state dict among the model's state dict. + This callback should receive a state dict, and returns a created state dict. + Examples: + >>> def state_dict_callback(state_dict, adapter_name): + >>> return { + >>> key: value + >>> for key, value in state_dict.items() if adapter_name in key + >>> } + save_callback (`FunctionType`): A callback used to save trained model. + mark_trainable_callback (`FunctionType`): A callback returned by the tuner + which is used to mark the tuner's adapter's parameters to trainable. + This callback should receive a model instance, and returns nothing. + Examples: + >>> def mark_trainable_callback(model): + >>> mark_lora_as_trainable(model, config.bias) + optimizer_group_callback (`FunctionType`): A callback returned the param group cared by the tuner. + load_state_dict_callback (`FunctionType`): A callback called before load_state_dict of the tuner. + load_callback (`FunctionType`): A callback used to load trained model. 
+ """ + model: torch.nn.Module = None + config: SwiftConfig = None + state_dict_callback: FunctionType = None + save_callback: FunctionType = None + mark_trainable_callback: FunctionType = None + optimizer_group_callback: FunctionType = None + load_state_dict_callback: FunctionType = None + load_callback: FunctionType = None + + +class ActivationMixin: + + USE_UNIQUE_THREAD = 'USE_UNIQUE_THREAD' + + REMINEDED = False + + def __init__(self, module_key): + self.module_key = module_key + self._thread_inf: Dict[int, Dict[str, bool]] = {} + self._unique_thread = bool(int(os.environ.get(ActivationMixin.USE_UNIQUE_THREAD, '1'))) + if not self._unique_thread and not ActivationMixin.REMINEDED: + ActivationMixin.REMINEDED = True + logger.warn('Using multiple thread mode, gradient checkpointing is not supported.') + + def mark_all_sub_modules_as_plugin(self: torch.nn.Module): + self.plugin = True + for name, module in self.named_modules(): + if 'base_layer' not in name: + module.plugin = True + + @property + def indent(self): + return 0 if self.unique_thread else threading.get_ident() + + @property + def unique_thread(self): + return self._unique_thread + + def set_activation(self, adapter_name, activate=True): + tid = self.indent + if tid not in self._thread_inf: + self._thread_inf[tid] = {} + self._thread_inf[tid][adapter_name] = activate + + def is_activated(self, adapter_name): + tid = self.indent + return self._thread_inf.get(tid, {}).get(adapter_name, False) + + def get_activated_adapters(self): + return [key for key, value in self._thread_inf.get(self.indent, {}).items() if value] + + +class OffloadHelper: + + def __init__(self): + cache_dir = os.path.join(get_cache_dir(), 'offload_cache') + os.makedirs(cache_dir, exist_ok=True) + tmp_dir = tempfile.TemporaryDirectory(dir=cache_dir) + self.cache_dir = tmp_dir.name + self._tmp_dir = tmp_dir + self.index = {} + + @staticmethod + def offload_weight(weight, weight_name, offload_folder, index=None): + dtype = None + if str(weight.dtype) == 'torch.bfloat16': + weight = weight.view(torch.int16) + dtype = 'bfloat16' + array = weight.cpu().numpy() + tensor_file = os.path.join(offload_folder, f'{weight_name}.dat') + if index is not None: + if dtype is None: + dtype = str(array.dtype) + index[weight_name] = {'dtype': dtype, 'shape': list(array.shape)} + if array.ndim == 0: + array = array[None] + file_array = np.memmap(tensor_file, dtype=array.dtype, mode='w+', shape=array.shape) + file_array[:] = array[:] + file_array.flush() + return index + + @staticmethod + def load_offloaded_weight(weight_file, weight_info): + shape = tuple(weight_info['shape']) + if shape == (): + shape = (1, ) + + dtype = weight_info['dtype'] + if dtype == 'bfloat16': + dtype = 'int16' + + weight = np.memmap(weight_file, dtype=dtype, shape=shape, mode='r') + + if len(weight_info['shape']) == 0: + weight = weight[0] + weight = torch.tensor(weight) + if weight_info['dtype'] == 'bfloat16': + weight = weight.view(torch.bfloat16) + + return weight + + def offload_disk(self, module: torch.nn.Module, adapter_name, module_key): + key = adapter_name + ':' + module_key + md5 = hashlib.md5(key.encode('utf-8')).hexdigest() + sub_folder = os.path.join(self.cache_dir, md5) + os.makedirs(sub_folder, exist_ok=True) + state_dict = module.state_dict() + self.index[md5] = {} + for key, tensor in state_dict.items(): + OffloadHelper.offload_weight(tensor, key, sub_folder, self.index[md5]) + + def load_disk(self, module: torch.nn.Module, adapter_name, module_key): + key = adapter_name + ':' + 
module_key + md5 = hashlib.md5(key.encode('utf-8')).hexdigest() + sub_folder = os.path.join(self.cache_dir, md5) + state_dict = {} + for key, value in self.index[md5].items(): + file = os.path.join(sub_folder, f'{key}.dat') + state_dict[key] = OffloadHelper.load_offloaded_weight(file, self.index[md5][key]) + if version.parse(torch.__version__) >= version.parse('2.1.0'): + module.load_state_dict(state_dict, assign=True) + else: + for name, _module in module.named_modules(): + if len(list(_module.modules())) > 1: + continue + + buffers = {} + prefix = name if not name else name + '.' + for sub_name, buffer in _module.named_buffers(): + buffer_cls = type(buffer) + buffers[sub_name] = buffer_cls(state_dict[prefix + sub_name]) + _module._buffers.update(buffers) + params = {} + for sub_name, param in _module.named_parameters(): + param_cls = type(param) + params[sub_name] = param_cls(state_dict[prefix + sub_name], requires_grad=param.requires_grad) + _module._parameters.update(params) + shutil.rmtree(sub_folder, ignore_errors=True) + + +class SwiftAdapter: + + offload_helper = OffloadHelper() + + @staticmethod + def prepare_model(model: torch.nn.Module, config: SwiftConfig, adapter_name: str) -> SwiftOutput: + raise NotImplementedError + + @staticmethod + def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None): + raise NotImplementedError + + @staticmethod + def save_memory(module: torch.nn.Module, adapter_name: str, module_key: str, activate: bool, offload: str = None): + if not isinstance(module, torch.nn.Module): + return + if activate: + SwiftAdapter.load(module, adapter_name, module_key) + else: + SwiftAdapter.offload(module, adapter_name, module_key, offload=offload) + + @staticmethod + def offload(module: torch.nn.Module, adapter_name, module_key, offload: str): + if not offload: + return + device = next(iter(module.parameters())).device + if hasattr(module, 'origin_device') and module.origin_device != str(device): + return + module.origin_device = str(device) + if offload == 'cpu': + if str(device) != 'cpu': + module.to('cpu') + elif offload == 'meta': + if str(device) != 'meta': + SwiftAdapter.offload_helper.offload_disk(module, adapter_name=adapter_name, module_key=module_key) + module.to('meta') + else: + raise NotImplementedError + gc_collect() + + @staticmethod + def load(module: torch.nn.Module, adapter_name, module_key): + device = next(iter(module.parameters())).device + if not hasattr(module, 'origin_device') or module.origin_device == str(device): + return + if str(device) == 'cpu': + module.to(module.origin_device) + delattr(module, 'origin_device') + elif str(device) == 'meta': + SwiftAdapter.offload_helper.load_disk(module, adapter_name=adapter_name, module_key=module_key) + module.to(module.origin_device) + delattr(module, 'origin_device') + + @classmethod + def get_model_key_mapping(cls, model_type, config) -> ModelKeys: + + if model_type in MODEL_ARCH_MAPPING.keys(): + model_key_mapping = MODEL_ARCH_MAPPING[model_type] + else: + model_key_mapping = config.model_key_mapping + + if model_key_mapping is None: + raise ValueError(f'{model_type} is not defined in MODEL_KEYS_MAPPING, ' + f'please consider pass the information through the config.model_key_mapping') + + if isinstance(model_key_mapping, dict): + model_key_mapping: ModelKeys = ModelKeys(**model_key_mapping) + return model_key_mapping + + @staticmethod + def state_dict_load_hook(model: torch.nn.Module, state_dict: Dict[str, torch.Tensor]): + pass + + @staticmethod + def 
has_additional_modules(): + return True + + +class ModulesToSaveWrapper(ActivationMixin, _ModulesToSaveWrapper): + + def __init__(self, *args, module_key, **kwargs): + super(ModulesToSaveWrapper, self).__init__(module_key) + super(ActivationMixin, self).__init__(*args, **kwargs) + SwiftAdapter.save_memory(self.original_module, 'original_module', self.module_key, False, offload='cpu') + + @property + def active_adapter(self): + active_adapters = self.get_activated_adapters() + if not active_adapters: + return None + elif len(active_adapters) > 1: + raise ValueError('ModulesToSaveWrapper does not support multiple active adapters') + return active_adapters[0] + + def set_adapter(self, adapter_name: str, offload: str = None): + if adapter_name not in self.modules_to_save: + raise ValueError(f'Adapter {adapter_name} not found in {self.modules_to_save.keys()}') + self.modules_to_save[adapter_name].requires_grad_(True) + self.set_activation(adapter_name, True) + SwiftAdapter.save_memory(self.modules_to_save[adapter_name], adapter_name, self.module_key, True) + SwiftAdapter.save_memory(self.original_module, 'original_module', self.module_key, False, offload=offload) + + def deactivate_adapter(self, adapter_name: str, offload: str = None): + if adapter_name in self.modules_to_save and self.unique_thread: + self.modules_to_save[adapter_name].requires_grad_(False) + self.set_activation(adapter_name, False) + SwiftAdapter.save_memory( + self.modules_to_save[adapter_name], adapter_name, self.module_key, False, offload=offload) + if not self.get_activated_adapters(): + SwiftAdapter.save_memory(self.original_module, 'original_module', self.module_key, True) + + def enable_adapters(self, enabled: bool): + super().enable_adapters(enabled) + if not enabled: + SwiftAdapter.save_memory(self.original_module, 'original_module', self.module_key, False, offload='meta') + else: + SwiftAdapter.save_memory(self.original_module, 'original_module', self.module_key, True) + + +def set_adapter(model, adapter_name, activate, offload): + for module in model.modules(): + if isinstance(module, ModulesToSaveWrapper): + if activate: + module.set_adapter(adapter_name, offload) + else: + module.deactivate_adapter(adapter_name, offload) + + +def set_trainable(model, adapter_name): + key_list = [key for key, _ in model.named_modules()] + for key in key_list: + target_module_found = any(key.endswith(target_key) for target_key in model.modules_to_save) + if target_module_found: + parent, target, target_name = _get_submodules(model, key) + if isinstance(target, ModulesToSaveWrapper): + target.update(adapter_name) + target.set_adapter(target.active_adapter) + else: + new_module = ModulesToSaveWrapper(target, module_key=key, adapter_name=adapter_name) + new_module.set_adapter(adapter_name) + setattr(parent, target_name, new_module) + + +def swift_to_peft_format(ckpt_dir: str, output_dir: str) -> str: + if 'default' in os.listdir(ckpt_dir): # swift_backend + from swift import Swift + Swift.save_to_peft_format(ckpt_dir, output_dir) + ckpt_dir = output_dir + logger.info(f'Converting the swift format checkpoint to peft format, and saving it to: `{output_dir}`') + else: + logger.info('The format of the checkpoint is already in peft format.') + return ckpt_dir diff --git a/swift/ui/__init__.py b/swift/ui/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb3b0163fb48e49cef87c02087e58472af76e74f --- /dev/null +++ b/swift/ui/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
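+# Usage sketch (illustrative): `from swift.ui import webui_main; webui_main()`
+# launches the Gradio-based SWIFT web UI defined in app.py.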
+from .app import webui_main diff --git a/swift/ui/app.py b/swift/ui/app.py new file mode 100644 index 0000000000000000000000000000000000000000..81df06f4ff32cf6e7af990980b6fd1f4a73373cb --- /dev/null +++ b/swift/ui/app.py @@ -0,0 +1,92 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from functools import partial +from typing import List, Union + +import gradio as gr +from packaging import version +from transformers.utils import strtobool + +import swift +from swift.llm import DeployArguments, EvalArguments, ExportArguments, RLHFArguments, SwiftPipeline, WebUIArguments +from swift.ui.llm_eval.llm_eval import LLMEval +from swift.ui.llm_export.llm_export import LLMExport +from swift.ui.llm_infer.llm_infer import LLMInfer +from swift.ui.llm_train.llm_train import LLMTrain + +locale_dict = { + 'title': { + 'zh': '🚀SWIFT: 轻量级大模型训练推理框架', + 'en': '🚀SWIFT: Scalable lightWeight Infrastructure for Fine-Tuning and Inference' + }, + 'sub_title': { + 'zh': + '请查看 ' + 'SWIFT 文档来查看更多功能,使用SWIFT_UI_LANG=en环境变量来切换英文界面', + 'en': + 'Please check ' + 'SWIFT Documentation for more usages, Use SWIFT_UI_LANG=zh variable to switch to Chinese UI', + }, + 'star_beggar': { + 'zh': + '喜欢SWIFT就动动手指给我们加个star吧🥺 ', + 'en': + 'If you like SWIFT, ' + 'please take a few seconds to star us🥺 ' + }, +} + + +class SwiftWebUI(SwiftPipeline): + + args_class = WebUIArguments + args: args_class + + def run(self): + lang = os.environ.get('SWIFT_UI_LANG') or self.args.lang + share_env = os.environ.get('WEBUI_SHARE') + share = strtobool(share_env) if share_env else self.args.share + server = os.environ.get('WEBUI_SERVER') or self.args.server_name + port_env = os.environ.get('WEBUI_PORT') + port = int(port_env) if port_env else self.args.server_port + LLMTrain.set_lang(lang) + LLMInfer.set_lang(lang) + LLMExport.set_lang(lang) + LLMEval.set_lang(lang) + with gr.Blocks(title='SWIFT WebUI', theme=gr.themes.Base()) as app: + try: + _version = swift.__version__ + except AttributeError: + _version = '' + gr.HTML(f"
{locale_dict['title'][lang]}({_version})
") + gr.HTML(f"
{locale_dict['sub_title'][lang]}
") + with gr.Tabs(): + LLMTrain.build_ui(LLMTrain) + LLMInfer.build_ui(LLMInfer) + LLMExport.build_ui(LLMExport) + LLMEval.build_ui(LLMEval) + + concurrent = {} + if version.parse(gr.__version__) < version.parse('4.0.0'): + concurrent = {'concurrency_count': 5} + app.load( + partial(LLMTrain.update_input_model, arg_cls=RLHFArguments), + inputs=[LLMTrain.element('model')], + outputs=[LLMTrain.element('train_record')] + list(LLMTrain.valid_elements().values())) + app.load( + partial(LLMInfer.update_input_model, arg_cls=DeployArguments, has_record=False), + inputs=[LLMInfer.element('model')], + outputs=list(LLMInfer.valid_elements().values())) + app.load( + partial(LLMExport.update_input_model, arg_cls=ExportArguments, has_record=False), + inputs=[LLMExport.element('model')], + outputs=list(LLMExport.valid_elements().values())) + app.load( + partial(LLMEval.update_input_model, arg_cls=EvalArguments, has_record=False), + inputs=[LLMEval.element('model')], + outputs=list(LLMEval.valid_elements().values())) + app.queue(**concurrent).launch(server_name=server, inbrowser=True, server_port=port, height=800, share=share) + + +def webui_main(args: Union[List[str], WebUIArguments, None] = None): + return SwiftWebUI(args).main() diff --git a/swift/ui/base.py b/swift/ui/base.py new file mode 100644 index 0000000000000000000000000000000000000000..6ca62a6fef2859964292f15e1bd4ac4fda029bbb --- /dev/null +++ b/swift/ui/base.py @@ -0,0 +1,388 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import dataclasses +import os +import sys +import time +import typing +from collections import OrderedDict +from dataclasses import fields +from datetime import datetime +from functools import wraps +from typing import Any, Dict, List, Type + +import gradio as gr +import json +from gradio import Accordion, Audio, Button, Checkbox, Dropdown, File, Image, Slider, Tab, TabItem, Textbox, Video +from modelscope.hub.utils.utils import get_cache_dir + +from swift.llm import TEMPLATE_MAPPING, BaseArguments, get_matched_model_meta + +all_langs = ['zh', 'en'] +builder: Type['BaseUI'] = None +base_builder: Type['BaseUI'] = None + + +def update_data(fn): + + @wraps(fn) + def wrapper(*args, **kwargs): + elem_id = kwargs.get('elem_id', None) + self = args[0] + + if builder is not None: + choices = base_builder.choice(elem_id) + if choices: + choices = [str(choice) if choice is not None else None for choice in choices] + kwargs['choices'] = choices + + if not isinstance(self, (Tab, TabItem, Accordion)) and 'interactive' not in kwargs: # noqa + kwargs['interactive'] = True + + if 'is_list' in kwargs: + self.is_list = kwargs.pop('is_list') + + if base_builder and base_builder.default(elem_id) is not None and not kwargs.get('value'): + kwargs['value'] = base_builder.default(elem_id) + + if builder is not None: + if elem_id in builder.locales(builder.lang): + values = builder.locale(elem_id, builder.lang) + if 'info' in values: + kwargs['info'] = values['info'] + if 'value' in values: + kwargs['value'] = values['value'] + if 'label' in values: + kwargs['label'] = values['label'] + if hasattr(builder, 'visible'): + kwargs['visible'] = builder.visible + argument = base_builder.argument(elem_id) + if argument and 'label' in kwargs: + kwargs['label'] = kwargs['label'] + f'({argument})' + + kwargs['elem_classes'] = 'align' + ret = fn(self, **kwargs) + self.constructor_args.update(kwargs) + + if builder is not None: + builder.element_dict[elem_id] = self + return ret + + return wrapper + + +Textbox.__init__ = update_data(Textbox.__init__) 
+Dropdown.__init__ = update_data(Dropdown.__init__) +Checkbox.__init__ = update_data(Checkbox.__init__) +Slider.__init__ = update_data(Slider.__init__) +TabItem.__init__ = update_data(TabItem.__init__) +Accordion.__init__ = update_data(Accordion.__init__) +Button.__init__ = update_data(Button.__init__) +File.__init__ = update_data(File.__init__) +Image.__init__ = update_data(Image.__init__) +Video.__init__ = update_data(Video.__init__) +Audio.__init__ = update_data(Audio.__init__) + + +class BaseUI: + + choice_dict: Dict[str, List] = {} + default_dict: Dict[str, Any] = {} + locale_dict: Dict[str, Dict] = {} + element_dict: Dict[str, Dict] = {} + arguments: Dict[str, str] = {} + sub_ui: List[Type['BaseUI']] = [] + group: str = None + lang: str = all_langs[0] + int_regex = r'^[-+]?[0-9]+$' + float_regex = r'[-+]?(?:\d*\.*\d+)' + bool_regex = r'^(T|t)rue$|^(F|f)alse$' + cache_dir = os.path.join(get_cache_dir(), 'swift-web-ui') + os.makedirs(cache_dir, exist_ok=True) + quote = '\'' if sys.platform != 'win32' else '"' + visible = True + _locale = { + 'local_dir_alert': { + 'value': { + 'zh': '无法识别model_type和template,请手动选择', + 'en': 'Cannot recognize the model_type and template, please choose manually' + } + }, + } + + @classmethod + def build_ui(cls, base_tab: Type['BaseUI']): + """Build UI""" + global builder, base_builder + cls.element_dict = {} + old_builder = builder + old_base_builder = base_builder + builder = cls + base_builder = base_tab + cls.do_build_ui(base_tab) + builder = old_builder + base_builder = old_base_builder + if cls is base_tab: + for ui in cls.sub_ui: + ui.after_build_ui(base_tab) + + @classmethod + def after_build_ui(cls, base_tab: Type['BaseUI']): + pass + + @classmethod + def do_build_ui(cls, base_tab: Type['BaseUI']): + """Build UI""" + pass + + @classmethod + def save_cache(cls, key, value): + timestamp = str(int(time.time())) + key = key.replace('/', '-') + filename = os.path.join(cls.cache_dir, key + '-' + timestamp) + with open(filename, 'w', encoding='utf-8') as f: + json.dump(value, f) + + @classmethod + def list_cache(cls, key): + files = [] + key = key.replace('/', '-') + for _, _, filenames in os.walk(cls.cache_dir): + for filename in filenames: + if filename.startswith(key): + idx = filename.rfind('-') + key, ts = filename[:idx], filename[idx + 1:] + dt_object = datetime.fromtimestamp(int(ts)) + formatted_time = dt_object.strftime('%Y/%m/%d %H:%M:%S') + files.append(formatted_time) + return sorted(files, reverse=True) + + @classmethod + def load_cache(cls, key, timestamp) -> BaseArguments: + dt_object = datetime.strptime(timestamp, '%Y/%m/%d %H:%M:%S') + timestamp = int(dt_object.timestamp()) + key = key.replace('/', '-') + filename = key + '-' + str(timestamp) + with open(os.path.join(cls.cache_dir, filename), 'r', encoding='utf-8') as f: + return json.load(f) + + @classmethod + def clear_cache(cls, key): + key = key.replace('/', '-') + for _, _, filenames in os.walk(cls.cache_dir): + for filename in filenames: + if filename.startswith(key): + os.remove(os.path.join(cls.cache_dir, filename)) + + @classmethod + def choice(cls, elem_id): + """Get choice by elem_id""" + for sub_ui in BaseUI.sub_ui: + _choice = sub_ui.choice(elem_id) + if _choice: + return _choice + return cls.choice_dict.get(elem_id, []) + + @classmethod + def default(cls, elem_id): + """Get choice by elem_id""" + if elem_id in cls.default_dict: + return cls.default_dict.get(elem_id) + for sub_ui in BaseUI.sub_ui: + _choice = sub_ui.default(elem_id) + if _choice: + return _choice + return 
None + + @classmethod + def locale(cls, elem_id, lang): + """Get locale by elem_id""" + return cls.locales(lang)[elem_id] + + @classmethod + def locales(cls, lang): + """Get locale by lang""" + locales = OrderedDict() + for sub_ui in cls.sub_ui: + _locales = sub_ui.locales(lang) + locales.update(_locales) + for key, value in cls.locale_dict.items(): + locales[key] = {k: v[lang] for k, v in value.items()} + return locales + + @classmethod + def elements(cls): + """Get all elements""" + elements = OrderedDict() + elements.update(cls.element_dict) + for sub_ui in cls.sub_ui: + _elements = sub_ui.elements() + elements.update(_elements) + return elements + + @classmethod + def valid_elements(cls): + valid_elements = OrderedDict() + elements = cls.elements() + for key, value in elements.items(): + if isinstance(value, (Textbox, Dropdown, Slider, Checkbox)) and key != 'train_record': + valid_elements[key] = value + return valid_elements + + @classmethod + def element_keys(cls): + return list(cls.elements().keys()) + + @classmethod + def valid_element_keys(cls): + return [ + key for key, value in cls.elements().items() + if isinstance(value, (Textbox, Dropdown, Slider, Checkbox)) and key != 'train_record' + ] + + @classmethod + def element(cls, elem_id): + """Get element by elem_id""" + elements = cls.elements() + return elements[elem_id] + + @classmethod + def argument(cls, elem_id): + """Get argument by elem_id""" + return cls.arguments.get(elem_id) + + @classmethod + def set_lang(cls, lang): + cls.lang = lang + for sub_ui in cls.sub_ui: + sub_ui.lang = lang + + @staticmethod + def get_choices_from_dataclass(dataclass): + choice_dict = {} + for f in fields(dataclass): + default_value = f.default + if 'MISSING_TYPE' in str(default_value): + default_value = None + if 'choices' in f.metadata: + choice_dict[f.name] = list(f.metadata['choices']) + if 'Literal' in str(f.type) and typing.get_args(f.type): + choice_dict[f.name] = list(typing.get_args(f.type)) + if f.name in choice_dict and default_value not in choice_dict[f.name]: + choice_dict[f.name].insert(0, default_value) + return choice_dict + + @staticmethod + def get_default_value_from_dataclass(dataclass): + default_dict = {} + for f in fields(dataclass): + if f.default.__class__ is dataclasses._MISSING_TYPE: + default_dict[f.name] = f.default_factory() + else: + default_dict[f.name] = f.default + if isinstance(default_dict[f.name], list): + try: + default_dict[f.name] = ' '.join(default_dict[f.name]) + except TypeError: + default_dict[f.name] = None + if not default_dict[f.name]: + default_dict[f.name] = None + return default_dict + + @staticmethod + def get_argument_names(dataclass): + arguments = {} + for f in fields(dataclass): + arguments[f.name] = f'--{f.name}' + return arguments + + @classmethod + def update_input_model(cls, model, allow_keys=None, has_record=True, arg_cls=BaseArguments, is_ref_model=False): + keys = cls.valid_element_keys() + if allow_keys: + keys = [key for key in keys if key in allow_keys] + + if not model: + ret = [gr.update()] * (len(keys) + int(has_record)) + if len(ret) == 1: + return ret[0] + else: + return ret + + model_meta = get_matched_model_meta(model) + local_args_path = os.path.join(model, 'args.json') + if model_meta is None and not os.path.exists(local_args_path): + gr.Info(cls._locale['local_dir_alert']['value'][cls.lang]) + ret = [gr.update()] * (len(keys) + int(has_record)) + if len(ret) == 1: + return ret[0] + else: + return ret + + if os.path.exists(local_args_path): + try: + if hasattr(arg_cls, 
'resume_from_checkpoint'): + try: + args = arg_cls(resume_from_checkpoint=model, load_data_args=True) + except Exception as e: + if 'using `--model`' in str(e): # TODO a dirty fix + args = arg_cls(model=model, load_data_args=True) + else: + raise e + else: + args = arg_cls(ckpt_dir=model, load_data_args=True) + except ValueError: + return [gr.update()] * (len(keys) + int(has_record)) + values = [] + for key in keys: + arg_value = getattr(args, key, None) + if arg_value and key != 'model': + if key in ('torch_dtype', 'bnb_4bit_compute_dtype'): + arg_value = str(arg_value).split('.')[1] + if isinstance(arg_value, list) and key != 'dataset': + try: + arg_value = ' '.join(arg_value) + except Exception: + arg_value = None + values.append(gr.update(value=arg_value)) + else: + values.append(gr.update()) + ret = [gr.update(choices=[])] * int(has_record) + values + if len(ret) == 1: + return ret[0] + else: + return ret + else: + values = [] + for key in keys: + if key not in ('template', 'model_type', 'ref_model_type', 'system'): + values.append(gr.update()) + elif key in ('template', 'model_type', 'ref_model_type'): + if key == 'ref_model_type': + if is_ref_model: + values.append(gr.update(value=getattr(model_meta, 'model_type'))) + else: + values.append(gr.update()) + else: + values.append(gr.update(value=getattr(model_meta, key))) + else: + values.append(gr.update(value=TEMPLATE_MAPPING[model_meta.template].default_system)) + + if has_record: + return [gr.update(choices=cls.list_cache(model))] + values + else: + if len(values) == 1: + return values[0] + return values + + @classmethod + def update_all_settings(cls, model, train_record, base_tab): + if not train_record: + return [gr.update()] * len(cls.elements()) + cache = cls.load_cache(model, train_record) + updates = [] + for key, value in base_tab.valid_elements().items(): + if key in cache: + updates.append(gr.update(value=cache[key])) + else: + updates.append(gr.update()) + return updates diff --git a/swift/ui/llm_eval/__init__.py b/swift/ui/llm_eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b937315b6e719ae8289fee2908aa486222eb76c5 --- /dev/null +++ b/swift/ui/llm_eval/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. diff --git a/swift/ui/llm_eval/eval.py b/swift/ui/llm_eval/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..ded9038bbae8d7a25e1bc2085bf74459fde787b5 --- /dev/null +++ b/swift/ui/llm_eval/eval.py @@ -0,0 +1,130 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
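+# Eval settings tab. Each elem_id below is expected to match an EvalArguments field
+# name, so the shared BaseUI machinery can append the corresponding `--xxx` argument
+# to the widget label; locale_dict provides the zh/en labels and tooltips.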
+from typing import Type + +import gradio as gr + +from swift.ui.base import BaseUI +from swift.utils import get_logger + +logger = get_logger() + + +class Eval(BaseUI): + + group = 'llm_eval' + + locale_dict = { + 'eval_backend': { + 'label': { + 'zh': '评测后端', + 'en': 'Eval backend' + }, + 'info': { + 'zh': '选择评测后端', + 'en': 'Select eval backend' + } + }, + 'eval_dataset': { + 'label': { + 'zh': '评测数据集', + 'en': 'Evaluation dataset' + }, + 'info': { + 'zh': '选择评测数据集,支持多选 (先选择评测后端)', + 'en': 'Select eval dataset, multiple datasets supported (select eval backend first)' + } + }, + 'eval_limit': { + 'label': { + 'zh': '评测数据个数', + 'en': 'Eval numbers for each dataset' + }, + 'info': { + 'zh': '每个评测集的取样数', + 'en': 'Number of rows sampled from each dataset' + } + }, + 'eval_output_dir': { + 'label': { + 'zh': '评测输出目录', + 'en': 'Eval output dir' + }, + 'info': { + 'zh': '评测结果的输出目录', + 'en': 'The dir to save the eval results' + } + }, + 'custom_eval_config': { + 'label': { + 'zh': '自定义数据集评测配置', + 'en': 'Custom eval config' + }, + 'info': { + 'zh': '可以使用该配置评测自己的数据集,详见github文档的评测部分', + 'en': 'Use this config to eval your own datasets, check the docs in github for details' + } + }, + 'eval_url': { + 'label': { + 'zh': '评测链接', + 'en': 'The eval url' + }, + 'info': { + 'zh': + 'OpenAI样式的评测链接(如:http://localhost:8080/v1/chat/completions),用于评测接口(模型类型输入为实际模型类型)', + 'en': + 'The OpenAI style link(like: http://localhost:8080/v1/chat/completions) for ' + 'evaluation(Input actual model type into model_type)' + } + }, + 'api_key': { + 'label': { + 'zh': '接口token', + 'en': 'The url token' + }, + 'info': { + 'zh': 'eval_url的token', + 'en': 'The token used with eval_url' + } + }, + 'infer_backend': { + 'label': { + 'zh': '推理框架', + 'en': 'Infer backend' + }, + } + } + + @classmethod + def do_build_ui(cls, base_tab: Type['BaseUI']): + try: + from swift.llm.argument.eval_args import EvalArguments + eval_dataset_dict = EvalArguments.list_eval_dataset() + default_backend = EvalArguments.eval_backend + except Exception as e: + logger.warn(e) + eval_dataset_dict = {} + default_backend = None + + with gr.Row(): + gr.Dropdown(elem_id='eval_backend', choices=list(eval_dataset_dict.keys()), value=default_backend, scale=20) + gr.Dropdown( + elem_id='eval_dataset', + is_list=True, + choices=eval_dataset_dict.get(default_backend, []), + multiselect=True, + allow_custom_value=True, + scale=20) + gr.Textbox(elem_id='eval_limit', scale=20) + gr.Dropdown(elem_id='infer_backend', scale=20) + with gr.Row(): + gr.Textbox(elem_id='custom_eval_config', scale=20) + gr.Textbox(elem_id='eval_output_dir', scale=20) + gr.Textbox(elem_id='eval_url', scale=20) + gr.Textbox(elem_id='api_key', scale=20) + + def update_eval_dataset(backend): + return gr.update(choices=eval_dataset_dict[backend]) + + cls.element('eval_backend').change(update_eval_dataset, [cls.element('eval_backend')], + [cls.element('eval_dataset')]) diff --git a/swift/ui/llm_eval/llm_eval.py b/swift/ui/llm_eval/llm_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..05824f1904756ca393678fed74957383665755b4 --- /dev/null +++ b/swift/ui/llm_eval/llm_eval.py @@ -0,0 +1,189 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
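+# This tab assembles EvalArguments from the UI widgets and launches evaluation as a
+# background shell command, roughly of the following shape (placeholders, not literal
+# values):
+#   CUDA_VISIBLE_DEVICES=<gpus> nohup swift eval --model '<model>' ... \
+#       --log_file '<output_dir>/run_eval.log' --ignore_args_error true > <log_file> 2>&1 &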
+import os +import re +import sys +import time +from datetime import datetime +from functools import partial +from typing import Type + +import gradio as gr +import json +import torch +from json import JSONDecodeError +from transformers.utils import is_torch_cuda_available, is_torch_npu_available + +from swift.llm import EvalArguments +from swift.ui.base import BaseUI +from swift.ui.llm_eval.eval import Eval +from swift.ui.llm_eval.model import Model +from swift.ui.llm_eval.runtime import EvalRuntime +from swift.utils import get_device_count + + +class LLMEval(BaseUI): + group = 'llm_eval' + + sub_ui = [Model, Eval, EvalRuntime] + + cmd = 'eval' + + locale_dict = { + 'llm_eval': { + 'label': { + 'zh': 'LLM评测', + 'en': 'LLM evaluation', + } + }, + 'more_params': { + 'label': { + 'zh': '更多参数', + 'en': 'More params' + }, + 'info': { + 'zh': '以json格式或--xxx xxx命令行格式填入', + 'en': 'Fill in with json format or --xxx xxx cmd format' + } + }, + 'evaluate': { + 'value': { + 'zh': '开始评测', + 'en': 'Begin Evaluation' + }, + }, + 'gpu_id': { + 'label': { + 'zh': '选择可用GPU', + 'en': 'Choose GPU' + }, + 'info': { + 'zh': '选择训练使用的GPU号,如CUDA不可用只能选择CPU', + 'en': 'Select GPU to train' + } + }, + } + + choice_dict = BaseUI.get_choices_from_dataclass(EvalArguments) + default_dict = BaseUI.get_default_value_from_dataclass(EvalArguments) + arguments = BaseUI.get_argument_names(EvalArguments) + + @classmethod + def do_build_ui(cls, base_tab: Type['BaseUI']): + with gr.TabItem(elem_id='llm_eval', label=''): + default_device = 'cpu' + device_count = get_device_count() + if device_count > 0: + default_device = '0' + with gr.Blocks(): + Model.build_ui(base_tab) + Eval.build_ui(base_tab) + EvalRuntime.build_ui(base_tab) + with gr.Row(): + gr.Textbox(elem_id='more_params', lines=4, scale=20) + gr.Button(elem_id='evaluate', scale=2, variant='primary') + gr.Dropdown( + elem_id='gpu_id', + multiselect=True, + choices=[str(i) for i in range(device_count)] + ['cpu'], + value=default_device, + scale=8) + + cls.element('evaluate').click( + cls.eval_model, list(base_tab.valid_elements().values()), + [cls.element('runtime_tab'), cls.element('running_tasks')]) + + base_tab.element('running_tasks').change( + partial(EvalRuntime.task_changed, base_tab=base_tab), [base_tab.element('running_tasks')], + list(base_tab.valid_elements().values()) + [cls.element('log')]) + EvalRuntime.element('kill_task').click( + EvalRuntime.kill_task, + [EvalRuntime.element('running_tasks')], + [EvalRuntime.element('running_tasks')] + [EvalRuntime.element('log')], + ) + + @classmethod + def eval(cls, *args): + eval_args = cls.get_default_value_from_dataclass(EvalArguments) + kwargs = {} + kwargs_is_list = {} + other_kwargs = {} + more_params = {} + more_params_cmd = '' + keys = cls.valid_element_keys() + for key, value in zip(keys, args): + compare_value = eval_args.get(key) + compare_value_arg = str(compare_value) if not isinstance(compare_value, (list, dict)) else compare_value + compare_value_ui = str(value) if not isinstance(value, (list, dict)) else value + if key in eval_args and compare_value_ui != compare_value_arg and value: + if isinstance(value, str) and re.fullmatch(cls.int_regex, value): + value = int(value) + elif isinstance(value, str) and re.fullmatch(cls.float_regex, value): + value = float(value) + elif isinstance(value, str) and re.fullmatch(cls.bool_regex, value): + value = True if value.lower() == 'true' else False + kwargs[key] = value if not isinstance(value, list) else ' '.join(value) + kwargs_is_list[key] = isinstance(value, list) 
or getattr(cls.element(key), 'is_list', False) + else: + other_kwargs[key] = value + if key == 'more_params' and value: + try: + more_params = json.loads(value) + except (JSONDecodeError or TypeError): + more_params_cmd = value + + kwargs.update(more_params) + model = kwargs.get('model') + if model and os.path.exists(model) and os.path.exists(os.path.join(model, 'args.json')): + kwargs['ckpt_dir'] = kwargs.pop('model') + + eval_args = EvalArguments( + **{ + key: value.split(' ') if key in kwargs_is_list and kwargs_is_list[key] else value + for key, value in kwargs.items() + }) + params = '' + sep = f'{cls.quote} {cls.quote}' + for e in kwargs: + if isinstance(kwargs[e], list): + params += f'--{e} {cls.quote}{sep.join(kwargs[e])}{cls.quote} ' + elif e in kwargs_is_list and kwargs_is_list[e]: + all_args = [arg for arg in kwargs[e].split(' ') if arg.strip()] + params += f'--{e} {cls.quote}{sep.join(all_args)}{cls.quote} ' + else: + params += f'--{e} {cls.quote}{kwargs[e]}{cls.quote} ' + params += more_params_cmd + ' ' + devices = other_kwargs['gpu_id'] + devices = [d for d in devices if d] + assert (len(devices) == 1 or 'cpu' not in devices) + gpus = ','.join(devices) + cuda_param = '' + if gpus != 'cpu': + if is_torch_npu_available(): + cuda_param = f'ASCEND_RT_VISIBLE_DEVICES={gpus}' + elif is_torch_cuda_available(): + cuda_param = f'CUDA_VISIBLE_DEVICES={gpus}' + else: + cuda_param = '' + now = datetime.now() + time_str = f'{now.year}{now.month}{now.day}{now.hour}{now.minute}{now.second}' + file_path = f'output/{eval_args.model_type}-{time_str}' + if not os.path.exists(file_path): + os.makedirs(file_path, exist_ok=True) + log_file = os.path.join(os.getcwd(), f'{file_path}/run_eval.log') + eval_args.log_file = log_file + params += f'--log_file "{log_file}" ' + params += '--ignore_args_error true ' + if sys.platform == 'win32': + if cuda_param: + cuda_param = f'set {cuda_param} && ' + run_command = f'{cuda_param}start /b swift eval {params} > {log_file} 2>&1' + else: + run_command = f'{cuda_param} nohup swift eval {params} > {log_file} 2>&1 &' + return run_command, eval_args, log_file + + @classmethod + def eval_model(cls, *args): + run_command, eval_args, log_file = cls.eval(*args) + os.system(run_command) + time.sleep(2) + return gr.update(open=True), EvalRuntime.refresh_tasks(log_file) diff --git a/swift/ui/llm_eval/model.py b/swift/ui/llm_eval/model.py new file mode 100644 index 0000000000000000000000000000000000000000..570afabf8c63d37a3d1487a97d2591102b93eefd --- /dev/null +++ b/swift/ui/llm_eval/model.py @@ -0,0 +1,78 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
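LLMEval.eval above follows the pattern shared by the export and deploy tabs: diff the UI values against the argument-dataclass defaults, fold the free-form "more params" box in as either JSON or raw `--key value` text, quote list-valued arguments, and assemble a `swift eval` shell command that is prefixed with a device-visibility variable and backgrounded with its output redirected to a timestamped log file. The sketch below re-creates just that command-assembly step in isolation; `build_command`, the sample argument names, and the use of `shlex.quote` (standing in for the UI's own quoting) are illustrative, not part of the swift codebase.

```python
import json
import os
import shlex
from datetime import datetime
from json import JSONDecodeError


def build_command(cmd: str, kwargs: dict, more_params_text: str = '', gpus: str = '0') -> str:
    """Illustrative re-creation of the command assembly used by the web UI."""
    # The "more params" box accepts either a JSON object or raw `--key value` text.
    extra_cli = ''
    if more_params_text:
        try:
            kwargs.update(json.loads(more_params_text))
        except (JSONDecodeError, TypeError):
            extra_cli = more_params_text

    # Quote every value; list values become individually quoted items.
    parts = []
    for key, value in kwargs.items():
        if isinstance(value, list):
            parts.append(f'--{key} ' + ' '.join(shlex.quote(str(v)) for v in value))
        else:
            parts.append(f'--{key} {shlex.quote(str(value))}')
    params = ' '.join(parts) + (f' {extra_cli}' if extra_cli else '')

    # Make only the selected devices visible, then background the process and
    # capture stdout/stderr in a timestamped log file.
    cuda_param = f'CUDA_VISIBLE_DEVICES={gpus}' if gpus != 'cpu' else ''
    time_str = datetime.now().strftime('%Y%m%d%H%M%S')
    log_file = os.path.join(os.getcwd(), f'output/run-{time_str}.log')
    return f'{cuda_param} nohup swift {cmd} {params} > {log_file} 2>&1 &'


# Example: roughly the kind of command the "Begin Evaluation" button ends up running.
print(build_command('eval', {'model': 'Qwen/Qwen2.5-7B-Instruct'}))
```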
+from functools import partial +from typing import Type + +import gradio as gr + +from swift.llm import TEMPLATE_MAPPING, EvalArguments, ModelType +from swift.llm.model.register import get_all_models +from swift.ui.base import BaseUI + + +class Model(BaseUI): + + group = 'llm_eval' + + locale_dict = { + 'checkpoint': { + 'value': { + 'zh': '训练后的模型', + 'en': 'Trained model' + } + }, + 'model_type': { + 'label': { + 'zh': '选择模型类型', + 'en': 'Select Model Type' + }, + 'info': { + 'zh': 'SWIFT已支持的模型类型', + 'en': 'Base model type supported by SWIFT' + } + }, + 'model': { + 'label': { + 'zh': '模型id或路径', + 'en': 'Model id or path' + }, + 'info': { + 'zh': '实际的模型id,如果是训练后的模型请填入checkpoint-xxx的目录', + 'en': 'The actual model id or path, if is a trained model, please fill in the checkpoint-xxx dir' + } + }, + 'reset': { + 'value': { + 'zh': '恢复初始值', + 'en': 'Reset to default' + }, + }, + 'template': { + 'label': { + 'zh': '模型Prompt模板类型', + 'en': 'Prompt template type' + }, + 'info': { + 'zh': '选择匹配模型的Prompt模板', + 'en': 'Choose the template type of the model' + } + }, + } + + @classmethod + def do_build_ui(cls, base_tab: Type['BaseUI']): + with gr.Row(): + gr.Dropdown( + elem_id='model', + scale=20, + choices=get_all_models(), + value='Qwen/Qwen2.5-7B-Instruct', + allow_custom_value=True) + gr.Dropdown(elem_id='model_type', choices=ModelType.get_model_name_list(), scale=20) + gr.Dropdown(elem_id='template', choices=list(TEMPLATE_MAPPING.keys()), scale=20) + + @classmethod + def after_build_ui(cls, base_tab: Type['BaseUI']): + cls.element('model').change( + partial(cls.update_input_model, arg_cls=EvalArguments, has_record=False), + inputs=[cls.element('model')], + outputs=list(cls.valid_elements().values())) diff --git a/swift/ui/llm_eval/runtime.py b/swift/ui/llm_eval/runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..03c90b81b0dfd454562a9ed1786ef224e0f0c3ce --- /dev/null +++ b/swift/ui/llm_eval/runtime.py @@ -0,0 +1,108 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
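Model.after_build_ui above wires the model dropdown's change event through BaseUI.update_input_model, so picking a model id refreshes the dependent fields. Below is a self-contained Gradio illustration of that wiring pattern; MODEL_META and on_model_change are hypothetical stand-ins for what update_input_model actually derives from the model registry and EvalArguments.

```python
import gradio as gr

# Hypothetical stand-in for the model registry consulted by update_input_model.
MODEL_META = {
    'Qwen/Qwen2.5-7B-Instruct': ('qwen2_5', 'qwen2_5'),
    'my-org/my-finetuned-model': ('llama3', 'llama3'),
}


def on_model_change(model_id: str):
    # Custom ids keep whatever the user already selected.
    if model_id not in MODEL_META:
        return gr.update(), gr.update()
    model_type, template = MODEL_META[model_id]
    return gr.update(value=model_type), gr.update(value=template)


with gr.Blocks() as demo:
    model = gr.Dropdown(choices=list(MODEL_META), allow_custom_value=True, label='model')
    model_type = gr.Dropdown(choices=sorted({m[0] for m in MODEL_META.values()}), label='model_type')
    template = gr.Dropdown(choices=sorted({m[1] for m in MODEL_META.values()}), label='template')
    # Changing the model dropdown refreshes model_type and template.
    model.change(on_model_change, inputs=[model], outputs=[model_type, template])

if __name__ == '__main__':
    demo.launch()
```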
+from typing import Type + +import gradio as gr +from packaging import version + +from swift.ui.base import BaseUI +from swift.ui.llm_infer.runtime import Runtime +from swift.utils import get_logger + +logger = get_logger() + + +class EvalRuntime(Runtime): + + group = 'llm_eval' + + cmd = 'eval' + + locale_dict = { + 'runtime_tab': { + 'label': { + 'zh': '运行时', + 'en': 'Runtime' + }, + }, + 'running_cmd': { + 'label': { + 'zh': '运行命令', + 'en': 'Command line' + }, + 'info': { + 'zh': '执行的实际命令', + 'en': 'The actual command' + } + }, + 'show_log': { + 'value': { + 'zh': '展示评测状态', + 'en': 'Show eval status' + }, + }, + 'stop_show_log': { + 'value': { + 'zh': '停止展示', + 'en': 'Stop showing running status' + }, + }, + 'log': { + 'label': { + 'zh': '日志输出', + 'en': 'Logging content' + }, + 'info': { + 'zh': '如果日志无更新请再次点击"展示日志内容"', + 'en': 'Please press "Show log" if the log content is not updating' + } + }, + 'running_tasks': { + 'label': { + 'zh': '运行中评测', + 'en': 'Running evaluation' + }, + 'info': { + 'zh': '所有的swift eval命令启动的任务', + 'en': 'All tasks started by swift eval' + } + }, + 'refresh_tasks': { + 'value': { + 'zh': '找回评测', + 'en': 'Find evaluation' + }, + }, + 'kill_task': { + 'value': { + 'zh': '杀死评测', + 'en': 'Kill evaluation' + }, + }, + } + + @classmethod + def do_build_ui(cls, base_tab: Type['BaseUI']): + with gr.Accordion(elem_id='runtime_tab', open=False, visible=True): + with gr.Blocks(): + with gr.Row(): + gr.Dropdown(elem_id='running_tasks', scale=10) + gr.Button(elem_id='refresh_tasks', scale=1, variant='primary') + gr.Button(elem_id='show_log', scale=1, variant='primary') + gr.Button(elem_id='stop_show_log', scale=1) + gr.Button(elem_id='kill_task', scale=1, size='lg') + with gr.Row(): + gr.Textbox(elem_id='log', lines=6, visible=False) + + concurrency_limit = {} + if version.parse(gr.__version__) >= version.parse('4.0.0'): + concurrency_limit = {'concurrency_limit': 5} + cls.log_event = base_tab.element('show_log').click(cls.update_log, [], [cls.element('log')]).then( + cls.wait, [base_tab.element('running_tasks')], [cls.element('log')], **concurrency_limit) + + base_tab.element('stop_show_log').click(cls.break_log_event, [cls.element('running_tasks')], []) + + base_tab.element('refresh_tasks').click( + cls.refresh_tasks, + [base_tab.element('running_tasks')], + [base_tab.element('running_tasks')], + ) diff --git a/swift/ui/llm_export/__init__.py b/swift/ui/llm_export/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b937315b6e719ae8289fee2908aa486222eb76c5 --- /dev/null +++ b/swift/ui/llm_export/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. diff --git a/swift/ui/llm_export/export.py b/swift/ui/llm_export/export.py new file mode 100644 index 0000000000000000000000000000000000000000..5d4ee80c3bbefcbcb4b232fa146a25f9857b5169 --- /dev/null +++ b/swift/ui/llm_export/export.py @@ -0,0 +1,89 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Type + +import gradio as gr + +from swift.llm.dataset.register import get_dataset_list +from swift.ui.base import BaseUI + + +class Export(BaseUI): + + group = 'llm_export' + + locale_dict = { + 'merge_lora': { + 'label': { + 'zh': '合并lora', + 'en': 'Merge lora' + }, + 'info': { + 'zh': + 'lora合并的路径在填入的checkpoint同级目录,请查看运行时log获取更具体的信息', + 'en': + 'The output path is in the sibling directory as the input checkpoint. ' + 'Please refer to the runtime log for more specific information.' 
+ }, + }, + 'device_map': { + 'label': { + 'zh': '合并lora使用的device_map', + 'en': 'The device_map when merge-lora' + }, + 'info': { + 'zh': '如果显存不够请填入cpu', + 'en': 'If GPU memory is not enough, fill in cpu' + }, + }, + 'quant_bits': { + 'label': { + 'zh': '量化比特数', + 'en': 'Quantize bits' + }, + }, + 'quant_method': { + 'label': { + 'zh': '量化方法', + 'en': 'Quantize method' + }, + }, + 'quant_n_samples': { + 'label': { + 'zh': '量化集采样数', + 'en': 'Sampled rows from calibration dataset' + }, + }, + 'max_length': { + 'label': { + 'zh': '量化集的max-length', + 'en': 'The quantize sequence length' + }, + }, + 'output_dir': { + 'label': { + 'zh': '输出路径', + 'en': 'Output dir' + }, + }, + 'dataset': { + 'label': { + 'zh': '校准数据集', + 'en': 'Calibration datasets' + }, + }, + } + + @classmethod + def do_build_ui(cls, base_tab: Type['BaseUI']): + with gr.Row(): + gr.Checkbox(elem_id='merge_lora', scale=10) + gr.Textbox(elem_id='device_map', scale=20) + with gr.Row(): + gr.Dropdown(elem_id='quant_bits', scale=20) + gr.Dropdown(elem_id='quant_method', scale=20) + gr.Textbox(elem_id='quant_n_samples', scale=20) + gr.Textbox(elem_id='max_length', scale=20) + with gr.Row(): + gr.Textbox(elem_id='output_dir', scale=20) + gr.Dropdown( + elem_id='dataset', multiselect=True, allow_custom_value=True, choices=get_dataset_list(), scale=20) diff --git a/swift/ui/llm_export/llm_export.py b/swift/ui/llm_export/llm_export.py new file mode 100644 index 0000000000000000000000000000000000000000..b71ccf6d7f3d12cf5cd279bff716d2b9557a4373 --- /dev/null +++ b/swift/ui/llm_export/llm_export.py @@ -0,0 +1,191 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import re +import sys +import time +from datetime import datetime +from functools import partial +from typing import Type + +import gradio as gr +import json +import torch +from json import JSONDecodeError +from transformers.utils import is_torch_cuda_available, is_torch_npu_available + +from swift.llm import ExportArguments +from swift.ui.base import BaseUI +from swift.ui.llm_export.export import Export +from swift.ui.llm_export.model import Model +from swift.ui.llm_export.runtime import ExportRuntime +from swift.utils import get_device_count + + +class LLMExport(BaseUI): + group = 'llm_export' + + sub_ui = [Model, Export, ExportRuntime] + + locale_dict = { + 'llm_export': { + 'label': { + 'zh': 'LLM导出', + 'en': 'LLM export', + } + }, + 'more_params': { + 'label': { + 'zh': '更多参数', + 'en': 'More params' + }, + 'info': { + 'zh': '以json格式或--xxx xxx命令行格式填入', + 'en': 'Fill in with json format or --xxx xxx cmd format' + } + }, + 'export': { + 'value': { + 'zh': '开始导出', + 'en': 'Begin Export' + }, + }, + 'gpu_id': { + 'label': { + 'zh': '选择可用GPU', + 'en': 'Choose GPU' + }, + 'info': { + 'zh': '选择使用的GPU号,如CUDA不可用只能选择CPU', + 'en': 'Select GPU to export' + } + }, + } + + choice_dict = BaseUI.get_choices_from_dataclass(ExportArguments) + default_dict = BaseUI.get_default_value_from_dataclass(ExportArguments) + arguments = BaseUI.get_argument_names(ExportArguments) + + @classmethod + def do_build_ui(cls, base_tab: Type['BaseUI']): + with gr.TabItem(elem_id='llm_export', label=''): + default_device = 'cpu' + device_count = get_device_count() + if device_count > 0: + default_device = '0' + with gr.Blocks(): + Model.build_ui(base_tab) + Export.build_ui(base_tab) + ExportRuntime.build_ui(base_tab) + with gr.Row(): + gr.Textbox(elem_id='more_params', lines=4, scale=20) + gr.Button(elem_id='export', scale=2, variant='primary') + gr.Dropdown( + elem_id='gpu_id', + multiselect=True, 
+ choices=[str(i) for i in range(device_count)] + ['cpu'], + value=default_device, + scale=8) + + cls.element('export').click( + cls.export_model, list(base_tab.valid_elements().values()), + [cls.element('runtime_tab'), cls.element('running_tasks')]) + + base_tab.element('running_tasks').change( + partial(ExportRuntime.task_changed, base_tab=base_tab), [base_tab.element('running_tasks')], + list(base_tab.valid_elements().values()) + [cls.element('log')]) + ExportRuntime.element('kill_task').click( + ExportRuntime.kill_task, + [ExportRuntime.element('running_tasks')], + [ExportRuntime.element('running_tasks')] + [ExportRuntime.element('log')], + ) + + @classmethod + def export(cls, *args): + export_args = cls.get_default_value_from_dataclass(ExportArguments) + kwargs = {} + kwargs_is_list = {} + other_kwargs = {} + more_params = {} + more_params_cmd = '' + keys = cls.valid_element_keys() + for key, value in zip(keys, args): + compare_value = export_args.get(key) + compare_value_arg = str(compare_value) if not isinstance(compare_value, (list, dict)) else compare_value + compare_value_ui = str(value) if not isinstance(value, (list, dict)) else value + if key in export_args and compare_value_ui != compare_value_arg and value: + if isinstance(value, str) and re.fullmatch(cls.int_regex, value): + value = int(value) + elif isinstance(value, str) and re.fullmatch(cls.float_regex, value): + value = float(value) + elif isinstance(value, str) and re.fullmatch(cls.bool_regex, value): + value = True if value.lower() == 'true' else False + kwargs[key] = value if not isinstance(value, list) else ' '.join(value) + kwargs_is_list[key] = isinstance(value, list) or getattr(cls.element(key), 'is_list', False) + else: + other_kwargs[key] = value + if key == 'more_params' and value: + try: + more_params = json.loads(value) + except (JSONDecodeError or TypeError): + more_params_cmd = value + + kwargs.update(more_params) + model = kwargs.get('model') + if os.path.exists(model) and os.path.exists(os.path.join(model, 'args.json')): + kwargs['ckpt_dir'] = kwargs.pop('model') + export_args = ExportArguments( + **{ + key: value.split(' ') if key in kwargs_is_list and kwargs_is_list[key] else value + for key, value in kwargs.items() + }) + params = '' + sep = f'{cls.quote} {cls.quote}' + for e in kwargs: + if isinstance(kwargs[e], list): + params += f'--{e} {cls.quote}{sep.join(kwargs[e])}{cls.quote} ' + elif e in kwargs_is_list and kwargs_is_list[e]: + all_args = [arg for arg in kwargs[e].split(' ') if arg.strip()] + params += f'--{e} {cls.quote}{sep.join(all_args)}{cls.quote} ' + else: + params += f'--{e} {cls.quote}{kwargs[e]}{cls.quote} ' + params += more_params_cmd + ' ' + devices = other_kwargs['gpu_id'] + devices = [d for d in devices if d] + assert (len(devices) == 1 or 'cpu' not in devices) + gpus = ','.join(devices) + cuda_param = '' + if gpus != 'cpu': + if is_torch_npu_available(): + cuda_param = f'ASCEND_RT_VISIBLE_DEVICES={gpus}' + elif is_torch_cuda_available(): + cuda_param = f'CUDA_VISIBLE_DEVICES={gpus}' + else: + cuda_param = '' + now = datetime.now() + time_str = f'{now.year}{now.month}{now.day}{now.hour}{now.minute}{now.second}' + file_path = f'output/{export_args.model_type}-{time_str}' + if not os.path.exists(file_path): + os.makedirs(file_path, exist_ok=True) + log_file = os.path.join(os.getcwd(), f'{file_path}/run_export.log') + export_args.log_file = log_file + params += f'--log_file "{log_file}" ' + params += '--ignore_args_error true ' + additional_param = '' + if export_args.quant_method == 
'gptq': + additional_param = 'OMP_NUM_THREADS=14' + if sys.platform == 'win32': + if cuda_param: + cuda_param = f'set {cuda_param} && ' + if additional_param: + additional_param = f'set {additional_param} && ' + run_command = f'{cuda_param}{additional_param}start /b swift export {params} > {log_file} 2>&1' + else: + run_command = f'{cuda_param} {additional_param} nohup swift export {params} > {log_file} 2>&1 &' + return run_command, export_args, log_file + + @classmethod + def export_model(cls, *args): + run_command, export_args, log_file = cls.export(*args) + os.system(run_command) + time.sleep(2) + return gr.update(open=True), ExportRuntime.refresh_tasks(log_file) diff --git a/swift/ui/llm_export/model.py b/swift/ui/llm_export/model.py new file mode 100644 index 0000000000000000000000000000000000000000..d42862f71ded65990b2104b7dda4d625a0953544 --- /dev/null +++ b/swift/ui/llm_export/model.py @@ -0,0 +1,83 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from functools import partial +from typing import Type + +import gradio as gr + +from swift.llm import TEMPLATE_MAPPING, ExportArguments, ModelType +from swift.llm.model.register import get_all_models +from swift.ui.base import BaseUI + + +class Model(BaseUI): + + group = 'llm_export' + + locale_dict = { + 'checkpoint': { + 'value': { + 'zh': '训练后的模型', + 'en': 'Trained model' + } + }, + 'model_type': { + 'label': { + 'zh': '选择模型类型', + 'en': 'Select Model Type' + }, + 'info': { + 'zh': 'SWIFT已支持的模型类型', + 'en': 'Base model type supported by SWIFT' + } + }, + 'model': { + 'label': { + 'zh': '模型id或路径', + 'en': 'Model id or path' + }, + 'info': { + 'zh': '实际的模型id,如果是训练后的模型请填入checkpoint-xxx的目录', + 'en': 'The actual model id or path, if is a trained model, please fill in the checkpoint-xxx dir' + } + }, + 'reset': { + 'value': { + 'zh': '恢复初始值', + 'en': 'Reset to default' + }, + }, + 'template': { + 'label': { + 'zh': '模型Prompt模板类型', + 'en': 'Prompt template type' + }, + 'info': { + 'zh': '选择匹配模型的Prompt模板', + 'en': 'Choose the template type of the model' + } + }, + } + + ignored_models = ['int1', 'int2', 'int4', 'int8', 'awq', 'gptq', 'bnb', 'eetq', 'aqlm', 'hqq'] + + @classmethod + def do_build_ui(cls, base_tab: Type['BaseUI']): + with gr.Row(): + all_models = [ + model for model in get_all_models() if not any([ignored in model for ignored in cls.ignored_models]) + ] + gr.Dropdown( + elem_id='model', + scale=20, + choices=all_models, + value='Qwen/Qwen2.5-7B-Instruct', + allow_custom_value=True) + gr.Dropdown(elem_id='model_type', choices=ModelType.get_model_name_list(), scale=20) + gr.Dropdown(elem_id='template', choices=list(TEMPLATE_MAPPING.keys()), scale=20) + + @classmethod + def after_build_ui(cls, base_tab: Type['BaseUI']): + cls.element('model').change( + partial(cls.update_input_model, arg_cls=ExportArguments, has_record=False), + inputs=[cls.element('model')], + outputs=list(cls.valid_elements().values())) diff --git a/swift/ui/llm_export/runtime.py b/swift/ui/llm_export/runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..f34ac4dfb0e917b2a9e1d9c3fdeb635c62315275 --- /dev/null +++ b/swift/ui/llm_export/runtime.py @@ -0,0 +1,75 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
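Two platform details in LLMExport.export above are easy to miss: GPTQ quantization additionally gets an `OMP_NUM_THREADS=14` prefix, and the env-var/backgrounding syntax differs between Windows (`set X && start /b`) and POSIX (`X=... nohup ... &`). A minimal sketch of just that wrapping logic follows (function and argument names are illustrative, and the log-file redirection is omitted):

```python
import sys


def wrap_command(base_cmd: str, cuda_param: str = '', quant_method: str = '') -> str:
    # The export tab pins OpenMP threads for GPTQ quantization runs.
    additional_param = 'OMP_NUM_THREADS=14' if quant_method == 'gptq' else ''
    if sys.platform == 'win32':
        # Windows: env vars via `set X=Y &&`, detach with `start /b`.
        prefix = ''.join(f'set {p} && ' for p in (cuda_param, additional_param) if p)
        return f'{prefix}start /b {base_cmd}'
    # POSIX: env assignments prefix the command; nohup + & keeps it running
    # after the Gradio request returns.
    prefix = ' '.join(p for p in (cuda_param, additional_param) if p)
    return f'{prefix} nohup {base_cmd} &'.strip()


# e.g. on Linux with one GPU selected:
print(wrap_command('swift export --quant_method gptq --quant_bits 4',
                   cuda_param='CUDA_VISIBLE_DEVICES=0', quant_method='gptq'))
```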
+from swift.ui.llm_infer.runtime import Runtime +from swift.utils import get_logger + +logger = get_logger() + + +class ExportRuntime(Runtime): + + group = 'llm_export' + + cmd = 'export' + + locale_dict = { + 'runtime_tab': { + 'label': { + 'zh': '运行时', + 'en': 'Runtime' + }, + }, + 'running_cmd': { + 'label': { + 'zh': '运行命令', + 'en': 'Command line' + }, + 'info': { + 'zh': '执行的实际命令', + 'en': 'The actual command' + } + }, + 'show_log': { + 'value': { + 'zh': '展示导出状态', + 'en': 'Show export status' + }, + }, + 'stop_show_log': { + 'value': { + 'zh': '停止展示', + 'en': 'Stop showing running status' + }, + }, + 'log': { + 'label': { + 'zh': '日志输出', + 'en': 'Logging content' + }, + 'info': { + 'zh': '如果日志无更新请再次点击"展示日志内容"', + 'en': 'Please press "Show log" if the log content is not updating' + } + }, + 'running_tasks': { + 'label': { + 'zh': '运行中导出任务', + 'en': 'Running export task' + }, + 'info': { + 'zh': '所有的swift export命令启动的任务', + 'en': 'All tasks started by swift export' + } + }, + 'refresh_tasks': { + 'value': { + 'zh': '找回导出任务', + 'en': 'Find export' + }, + }, + 'kill_task': { + 'value': { + 'zh': '杀死导出任务', + 'en': 'Kill export' + }, + }, + } diff --git a/swift/ui/llm_infer/__init__.py b/swift/ui/llm_infer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b937315b6e719ae8289fee2908aa486222eb76c5 --- /dev/null +++ b/swift/ui/llm_infer/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. diff --git a/swift/ui/llm_infer/generate.py b/swift/ui/llm_infer/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..b83b212a95b2efb2c981522e96641f565bb61f05 --- /dev/null +++ b/swift/ui/llm_infer/generate.py @@ -0,0 +1,65 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Type + +import gradio as gr + +from swift.ui.base import BaseUI + + +class Generate(BaseUI): + + group = 'llm_infer' + + locale_dict = { + 'max_new_tokens': { + 'label': { + 'zh': '生成序列最大长度', + 'en': 'Max new tokens' + }, + }, + 'temperature': { + 'label': { + 'zh': 'temperature', + 'en': 'temperature' + }, + }, + 'top_k': { + 'label': { + 'zh': 'top_k', + 'en': 'top_k' + }, + }, + 'top_p': { + 'label': { + 'zh': 'top_p', + 'en': 'top_p' + }, + }, + 'repetition_penalty': { + 'label': { + 'zh': 'repetition_penalty', + 'en': 'repetition_penalty' + }, + }, + 'system': { + 'label': { + 'zh': 'system字段', + 'en': 'system' + }, + 'info': { + 'zh': 'system字段支持在加载模型后修改', + 'en': 'system can be modified after the model weights loaded' + } + }, + } + + @classmethod + def do_build_ui(cls, base_tab: Type['BaseUI']): + with gr.Row(): + gr.Textbox(elem_id='max_new_tokens', lines=1, value='2048') + gr.Slider(elem_id='temperature', minimum=0.0, maximum=10, step=0.1, value=0.3) + gr.Slider(elem_id='top_k', minimum=1, maximum=100, step=5, value=20) + gr.Slider(elem_id='top_p', minimum=0.0, maximum=1.0, step=0.05, value=0.7) + gr.Slider(elem_id='repetition_penalty', minimum=0.0, maximum=10, step=0.05, value=1.05) + with gr.Row(): + gr.Textbox(elem_id='system', lines=4, scale=20) diff --git a/swift/ui/llm_infer/llm_infer.py b/swift/ui/llm_infer/llm_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..70480631879158441604e7d9034b8beb048f3181 --- /dev/null +++ b/swift/ui/llm_infer/llm_infer.py @@ -0,0 +1,396 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
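The sliders defined in Generate above map directly onto sampling parameters that LLMInfer later forwards to the server started by `swift deploy`, via the InferClient/InferRequest/RequestConfig classes imported just below. A rough, non-streaming sketch of that hand-off is shown here; the field names are a best guess from those classes and should be checked against the installed swift version, and the host/port values are placeholders.

```python
from swift.llm import InferClient, InferRequest, RequestConfig

# Values as they would arrive from the Generate tab's controls.
request_config = RequestConfig(
    max_tokens=2048,          # 'max_new_tokens' textbox
    temperature=0.3,
    top_k=20,
    top_p=0.7,
    repetition_penalty=1.05,
)

# The inference tab talks to the server launched by `swift deploy`.
client = InferClient(host='127.0.0.1', port=8000)
request = InferRequest(messages=[
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'Hello!'},
])
resp = client.infer([request], request_config=request_config)[0]
print(resp.choices[0].message.content)
```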
+import os +import re +import signal +import sys +import time +from copy import deepcopy +from datetime import datetime +from functools import partial +from typing import List, Type + +import gradio as gr +import json +import torch +from json import JSONDecodeError +from transformers.utils import is_torch_cuda_available, is_torch_npu_available + +from swift.llm import DeployArguments, InferArguments, InferClient, InferRequest, RequestConfig +from swift.ui.base import BaseUI +from swift.ui.llm_infer.model import Model +from swift.ui.llm_infer.runtime import Runtime +from swift.utils import get_device_count, get_logger + +logger = get_logger() + + +class LLMInfer(BaseUI): + + group = 'llm_infer' + + is_multimodal = True + + sub_ui = [Model, Runtime] + + locale_dict = { + 'generate_alert': { + 'value': { + 'zh': '请先部署模型', + 'en': 'Please deploy model first', + } + }, + 'port': { + 'label': { + 'zh': '端口', + 'en': 'port' + }, + }, + 'llm_infer': { + 'label': { + 'zh': 'LLM推理', + 'en': 'LLM Inference', + } + }, + 'load_alert': { + 'value': { + 'zh': '部署中,请点击"展示部署状态"查看', + 'en': 'Start to deploy model, ' + 'please Click "Show running ' + 'status" to view details', + } + }, + 'loaded_alert': { + 'value': { + 'zh': '模型加载完成', + 'en': 'Model loaded' + } + }, + 'port_alert': { + 'value': { + 'zh': '该端口已被占用', + 'en': 'The port has been occupied' + } + }, + 'chatbot': { + 'value': { + 'zh': '对话框', + 'en': 'Chat bot' + }, + }, + 'infer_model_type': { + 'label': { + 'zh': 'Lora模块', + 'en': 'Lora module' + }, + 'info': { + 'zh': '发送给server端哪个LoRA,默认为`default`', + 'en': 'Which LoRA to use on server, default value is `default`' + } + }, + 'prompt': { + 'label': { + 'zh': '请输入:', + 'en': 'Input:' + }, + }, + 'clear_history': { + 'value': { + 'zh': '清除对话信息', + 'en': 'Clear history' + }, + }, + 'submit': { + 'value': { + 'zh': '🚀 发送', + 'en': '🚀 Send' + }, + }, + 'gpu_id': { + 'label': { + 'zh': '选择可用GPU', + 'en': 'Choose GPU' + }, + 'info': { + 'zh': '选择训练使用的GPU号,如CUDA不可用只能选择CPU', + 'en': 'Select GPU to train' + } + }, + } + + choice_dict = BaseUI.get_choices_from_dataclass(InferArguments) + default_dict = BaseUI.get_default_value_from_dataclass(InferArguments) + arguments = BaseUI.get_argument_names(InferArguments) + + @classmethod + def do_build_ui(cls, base_tab: Type['BaseUI']): + with gr.TabItem(elem_id='llm_infer', label=''): + default_device = 'cpu' + device_count = get_device_count() + if device_count > 0: + default_device = '0' + with gr.Blocks(): + infer_request = gr.State(None) + Model.build_ui(base_tab) + Runtime.build_ui(base_tab) + with gr.Row(): + gr.Dropdown( + elem_id='gpu_id', + multiselect=True, + choices=[str(i) for i in range(device_count)] + ['cpu'], + value=default_device, + scale=8) + infer_model_type = gr.Textbox(elem_id='infer_model_type', scale=4) + gr.Textbox(elem_id='port', lines=1, value='8000', scale=4) + chatbot = gr.Chatbot(elem_id='chatbot', elem_classes='control-height') + with gr.Row(): + prompt = gr.Textbox(elem_id='prompt', lines=1, interactive=True) + with gr.Tabs(visible=cls.is_multimodal): + with gr.TabItem(label='Image'): + image = gr.Image(type='filepath') + with gr.TabItem(label='Video'): + video = gr.Video() + with gr.TabItem(label='Audio'): + audio = gr.Audio(type='filepath') + + with gr.Row(): + clear_history = gr.Button(elem_id='clear_history') + submit = gr.Button(elem_id='submit') + + cls.element('load_checkpoint').click( + cls.deploy_model, list(base_tab.valid_elements().values()), + [cls.element('runtime_tab'), cls.element('running_tasks')]) + submit.click( + 
cls.send_message, + inputs=[ + cls.element('running_tasks'), + cls.element('template'), prompt, image, video, audio, infer_request, infer_model_type, + cls.element('system'), + cls.element('max_new_tokens'), + cls.element('temperature'), + cls.element('top_k'), + cls.element('top_p'), + cls.element('repetition_penalty') + ], + outputs=[prompt, chatbot, image, video, audio, infer_request], + queue=True) + + clear_history.click( + fn=cls.clear_session, inputs=[], outputs=[prompt, chatbot, image, video, audio, infer_request]) + + base_tab.element('running_tasks').change( + partial(Runtime.task_changed, base_tab=base_tab), [base_tab.element('running_tasks')], + list(cls.valid_elements().values()) + [cls.element('log')]) + Runtime.element('kill_task').click( + Runtime.kill_task, + [Runtime.element('running_tasks')], + [Runtime.element('running_tasks')] + [Runtime.element('log')], + ) + + @classmethod + def deploy(cls, *args): + deploy_args = cls.get_default_value_from_dataclass(DeployArguments) + kwargs = {} + kwargs_is_list = {} + other_kwargs = {} + more_params = {} + more_params_cmd = '' + keys = cls.valid_element_keys() + for key, value in zip(keys, args): + compare_value = deploy_args.get(key) + compare_value_arg = str(compare_value) if not isinstance(compare_value, (list, dict)) else compare_value + compare_value_ui = str(value) if not isinstance(value, (list, dict)) else value + if key in deploy_args and compare_value_ui != compare_value_arg and value: + if isinstance(value, str) and re.fullmatch(cls.int_regex, value): + value = int(value) + elif isinstance(value, str) and re.fullmatch(cls.float_regex, value): + value = float(value) + elif isinstance(value, str) and re.fullmatch(cls.bool_regex, value): + value = True if value.lower() == 'true' else False + kwargs[key] = value if not isinstance(value, list) else ' '.join(value) + kwargs_is_list[key] = isinstance(value, list) or getattr(cls.element(key), 'is_list', False) + else: + other_kwargs[key] = value + if key == 'more_params' and value: + try: + more_params = json.loads(value) + except (JSONDecodeError or TypeError): + more_params_cmd = value + + kwargs.update(more_params) + model = kwargs.get('model') + if os.path.exists(model) and os.path.exists(os.path.join(model, 'args.json')): + kwargs['ckpt_dir'] = kwargs.pop('model') + with open(os.path.join(kwargs['ckpt_dir'], 'args.json'), 'r', encoding='utf-8') as f: + _json = json.load(f) + kwargs['model_type'] = _json['model_type'] + kwargs['train_type'] = _json['train_type'] + deploy_args = DeployArguments( + **{ + key: value.split(' ') if key in kwargs_is_list and kwargs_is_list[key] else value + for key, value in kwargs.items() + }) + if deploy_args.port in Runtime.get_all_ports(): + raise gr.Error(cls.locale('port_alert', cls.lang)['value']) + params = '' + sep = f'{cls.quote} {cls.quote}' + for e in kwargs: + if isinstance(kwargs[e], list): + params += f'--{e} {cls.quote}{sep.join(kwargs[e])}{cls.quote} ' + elif e in kwargs_is_list and kwargs_is_list[e]: + all_args = [arg for arg in kwargs[e].split(' ') if arg.strip()] + params += f'--{e} {cls.quote}{sep.join(all_args)}{cls.quote} ' + else: + params += f'--{e} {cls.quote}{kwargs[e]}{cls.quote} ' + if 'port' not in kwargs: + params += f'--port "{deploy_args.port}" ' + params += more_params_cmd + ' ' + devices = other_kwargs['gpu_id'] + devices = [d for d in devices if d] + assert (len(devices) == 1 or 'cpu' not in devices) + gpus = ','.join(devices) + cuda_param = '' + if gpus != 'cpu': + if is_torch_npu_available(): + cuda_param = 
f'ASCEND_RT_VISIBLE_DEVICES={gpus}' + elif is_torch_cuda_available(): + cuda_param = f'CUDA_VISIBLE_DEVICES={gpus}' + else: + cuda_param = '' + now = datetime.now() + time_str = f'{now.year}{now.month}{now.day}{now.hour}{now.minute}{now.second}' + file_path = f'output/{deploy_args.model_type}-{time_str}' + if not os.path.exists(file_path): + os.makedirs(file_path, exist_ok=True) + log_file = os.path.join(os.getcwd(), f'{file_path}/run_deploy.log') + deploy_args.log_file = log_file + params += f'--log_file "{log_file}" ' + params += '--ignore_args_error true ' + if sys.platform == 'win32': + if cuda_param: + cuda_param = f'set {cuda_param} && ' + run_command = f'{cuda_param}start /b swift deploy {params} > {log_file} 2>&1' + else: + run_command = f'{cuda_param} nohup swift deploy {params} > {log_file} 2>&1 &' + return run_command, deploy_args, log_file + + @classmethod + def deploy_model(cls, *args): + run_command, deploy_args, log_file = cls.deploy(*args) + logger.info(f'Running deployment command: {run_command}') + os.system(run_command) + gr.Info(cls.locale('load_alert', cls.lang)['value']) + time.sleep(2) + running_task = Runtime.refresh_tasks(log_file) + return gr.update(open=True), running_task + + @classmethod + def register_clean_hook(cls): + signal.signal(signal.SIGINT, LLMInfer.signal_handler) + if os.name != 'nt': + signal.signal(signal.SIGTERM, LLMInfer.signal_handler) + + @staticmethod + def signal_handler(*args, **kwargs): + LLMInfer.clean_deployment() + sys.exit(0) + + @classmethod + def clear_session(cls): + return '', [], gr.update(value=None), gr.update(value=None), gr.update(value=None), [] + + @classmethod + def _replace_tag_with_media(cls, infer_request: InferRequest): + total_history = [] + messages = deepcopy(infer_request.messages) + if messages[0]['role'] == 'system': + messages.pop(0) + for i in range(0, len(messages), 2): + slices = messages[i:i + 2] + if len(slices) == 2: + user, assistant = slices + else: + user = slices[0] + assistant = {'role': 'assistant', 'content': None} + user['content'] = (user['content'] or '').replace('', '').replace('