|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import inspect
|
| import logging
|
| from copy import deepcopy
|
| from dataclasses import dataclass
|
| from typing import Any
|
|
|
| import pyarrow as pa
|
| import pyarrow.types
|
| import torch
|
| from accelerate.utils import is_peft_model
|
| from packaging.version import Version
|
| from pyarrow import compute as pc
|
| from torch import nn
|
| from torch.nn.utils.rnn import pad_sequence
|
| from transformers import PreTrainedModel, PreTrainedTokenizerBase, TrainingArguments
|
| from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
| from transformers.utils import (
|
| is_peft_available,
|
| is_torch_mlu_available,
|
| is_torch_npu_available,
|
| is_torch_xpu_available,
|
| )
|
|
|
| from ..data_utils import DatasetType, _get_dataset_format
|
| from ..trainer.utils import pad
|
|
|
|
|
| if is_peft_available():
|
| import peft
|
| from peft import PeftConfig, PeftModel, get_peft_model
|
|
|
|
|
| @dataclass
|
| class DPODataCollatorWithPadding:
|
| r"""
|
| DPO DataCollator class that pads the tokenized inputs to the maximum length of the batch.
|
|
|
| Args:
|
| pad_token_id (`int` defaults to 0):
|
| The tokenizer's pad_token_id.
|
| is_encoder_decoder (`bool` or `None`, `optional`, defaults to `None`):
|
| Whether you model has an encoder_decoder architecture.
|
| """
|
|
|
| pad_token_id: int = 0
|
| is_encoder_decoder: bool | None = False
|
|
|
| def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
|
|
|
| padded_batch = {}
|
| for k in features[0].keys():
|
| if k.endswith(("_input_ids", "_attention_mask", "_labels", "_pixel_values")):
|
| if self.is_encoder_decoder:
|
| to_pad = [torch.LongTensor(ex[k]) for ex in features]
|
|
|
| if (k.startswith("prompt")) and (k.endswith("input_ids")):
|
| if self.pad_token_id is None:
|
| raise ValueError(
|
| "Padding is enabled, but the tokenizer is not configured with a padding token."
|
| " Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)"
|
| " before calling the trainer."
|
| )
|
| padding_value = self.pad_token_id
|
| elif k.endswith("_attention_mask"):
|
| padding_value = 0
|
| elif k.startswith(("chosen", "rejected", "completion")) or ("decoder" in k):
|
| padding_value = -100
|
| else:
|
| raise ValueError(f"Unexpected key in batch '{k}'")
|
| padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
|
| else:
|
|
|
| if k.endswith("_input_ids"):
|
| if self.pad_token_id is None:
|
| raise ValueError(
|
| "Padding is enabled, but the tokenizer is not configured with a padding token."
|
| " Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)"
|
| " before calling the trainer."
|
| )
|
| padding_value = self.pad_token_id
|
| elif k.endswith("_labels"):
|
| padding_value = -100
|
| elif k.endswith("_attention_mask"):
|
| padding_value = 0
|
| elif k.endswith("_pixel_values"):
|
| padding_value = 0
|
| else:
|
| raise ValueError(f"Unexpected key in batch '{k}'")
|
|
|
|
|
| if k in ["prompt_input_ids", "prompt_attention_mask"]:
|
| padding_side = "left"
|
| else:
|
| padding_side = "right"
|
|
|
|
|
| if k.endswith("_pixel_values"):
|
| dtype = torch.float32
|
| else:
|
| dtype = torch.int64
|
|
|
|
|
| to_pad = [torch.tensor(ex[k], dtype=dtype) for ex in features]
|
| padded_batch[k] = pad(to_pad, padding_value=padding_value, padding_side=padding_side)
|
| elif k.endswith("_logps"):
|
|
|
| padded_batch[k] = torch.tensor([ex[k] for ex in features])
|
| else:
|
| padded_batch[k] = [ex[k] for ex in features]
|
|
|
| return padded_batch
|
|
|
|
|
@dataclass
class DataCollatorForChatML:
    """
    Data collator for ChatML format datasets.

    Each example is either a conversation (under `messages_key`, rendered with the tokenizer's chat template) or an
    already tokenized sequence (under `"input_ids"`). The collator produces left-padded `input_ids`,
    `attention_mask` and `labels` tensors, plus the left-padded prompt ids and their mask. Labels equal the input ids
    on completion positions and `ignore_index` on prompt positions.

    Args:
        tokenizer (`PreTrainedTokenizerBase`):
            Tokenizer used for chat templating and tokenization. It must define `pad_token_id`.
        ignore_index (`int`, defaults to `-100`):
            Label value written on prompt positions and used to pad `labels`.
        max_length (`int` or `None`, defaults to `None`):
            Maximum number of tokens kept per conversation. When `None`, it is set in `__post_init__` to
            `min(tokenizer.model_max_length, 1024)`.
        prompt_key (`str`, defaults to `"prompt"`):
            Example key holding an already formatted prompt string. When absent, the prompt is rendered from all
            messages except the last one.
        messages_key (`str`, defaults to `"messages"`):
            Example key holding the list of chat messages.
    """

    tokenizer: PreTrainedTokenizerBase
    ignore_index: int = -100
    max_length: int | None = None
    prompt_key: str = "prompt"
    messages_key: str = "messages"

    def __post_init__(self):
        if self.tokenizer.pad_token_id is None:
            raise ValueError("The tokenizer does not have a pad token. Please set `pad_token_id` in the tokenizer.")
        if self.max_length is None:
            # Default to the tokenizer limit, capped at 1024 tokens.
            self.max_length = min(self.tokenizer.model_max_length, 1024)

    def _encode(self, text: str, **extra: Any):
        """Tokenize `text` as-is: no truncation, no padding, no extra special tokens."""
        return self.tokenizer(
            text,
            truncation=False,
            padding=False,
            return_tensors=None,
            add_special_tokens=False,
            **extra,
        )

    def _split_prompt_completion(self, formatted_prompt: str, formatted_message: str) -> tuple[list[int], list[int]]:
        """Tokenize the full conversation and split its ids into `(prompt_ids, completion_ids)`."""
        try:
            encoded = self._encode(formatted_message, return_offsets_mapping=True)
        except NotImplementedError:
            # Python ("slow") tokenizers raise instead of omitting the offsets; retry without them so the
            # prompt-length fallback below is actually reachable.
            encoded = self._encode(formatted_message)

        message_ids = encoded["input_ids"]
        offsets = encoded.get("offset_mapping")

        if offsets is not None:
            # The completion starts at the first token whose character offset is past the prompt text.
            prompt_char_len = len(formatted_prompt)
            boundary = next(
                (idx for idx, (start, _) in enumerate(offsets) if start >= prompt_char_len),
                len(message_ids),
            )
        else:
            # Without offsets, use the token count of the prompt tokenized on its own.
            boundary = len(self._encode(formatted_prompt)["input_ids"])

        return message_ids[:boundary], message_ids[boundary:]

    def _fit_to_max_length(self, prompt_ids: list[int], completion_ids: list[int]) -> tuple[list[int], list[int]]:
        """Trim `(prompt_ids, completion_ids)` to `max_length` tokens, dropping prompt tokens first."""
        if self.max_length is None or len(prompt_ids) + len(completion_ids) <= self.max_length:
            return prompt_ids, completion_ids

        if len(completion_ids) >= self.max_length:
            # The completion alone fills the budget: keep its tail and drop the prompt entirely.
            return [], completion_ids[-self.max_length :]

        # Keep the whole completion and the most recent prompt tokens that still fit.
        max_prompt_tokens = self.max_length - len(completion_ids)
        return prompt_ids[-max_prompt_tokens:], completion_ids

    def __call__(self, examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
        """
        Collate `examples` into a batch.

        Returns:
            `dict[str, torch.Tensor]` with keys `"input_ids"`, `"attention_mask"`, `"labels"`, `"prompts"` and
            `"prompt_attention_mask"`; every tensor is padded on the left.
        """
        input_ids = []
        attention_mask = []
        prompts_input_ids = []
        prompt_attention_mask = []
        labels = []

        for example in examples:
            formatted_prompt = example.get(self.prompt_key, None)
            if formatted_prompt is None:
                # Render the prompt from every message but the last, ending with the generation prompt.
                formatted_prompt = self.tokenizer.apply_chat_template(
                    example[self.messages_key][:-1], add_generation_prompt=True, tokenize=False
                )

            if "input_ids" not in example:
                formatted_message = self.tokenizer.apply_chat_template(
                    example[self.messages_key], add_generation_prompt=False, tokenize=False
                )
                prompt_ids, completion_ids = self._split_prompt_completion(formatted_prompt, formatted_message)
                prompt_ids, completion_ids = self._fit_to_max_length(prompt_ids, completion_ids)
                message_input_ids = prompt_ids + completion_ids

                input_ids.append(message_input_ids)
                attention_mask.append([1] * len(message_input_ids))
                current_prompt_ids = prompt_ids
            else:
                # Pre-tokenized example: reuse its ids (and mask when provided).
                message_input_ids = example["input_ids"]
                input_ids.append(message_input_ids)
                if "attention_mask" in example:
                    attention_mask.append(example["attention_mask"])
                else:
                    attention_mask.append([1] * len(message_input_ids))

                # The prompt is capped at the sequence length so the label split stays in range.
                tokenized_prompt = self.tokenizer(
                    formatted_prompt,
                    truncation=True,
                    max_length=len(message_input_ids),
                    padding=False,
                    return_tensors=None,
                    add_special_tokens=False,
                )
                current_prompt_ids = tokenized_prompt["input_ids"]

            prompts_input_ids.append(current_prompt_ids)
            prompt_attention_mask.append([1] * len(current_prompt_ids))

            # Mask the prompt positions; only completion tokens contribute to the labels.
            label = [self.ignore_index] * len(input_ids[-1])
            completion_start_idx = len(current_prompt_ids)
            label[completion_start_idx:] = input_ids[-1][completion_start_idx:]
            labels.append(label)

        # Convert to tensors and left-pad each field.
        input_ids = [torch.tensor(ids, dtype=torch.long) for ids in input_ids]
        attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in attention_mask]
        labels = [torch.tensor(label, dtype=torch.long) for label in labels]
        input_ids = pad(input_ids, padding_side="left", padding_value=self.tokenizer.pad_token_id)
        attention_mask = pad(attention_mask, padding_side="left", padding_value=0)
        labels = pad(labels, padding_side="left", padding_value=self.ignore_index)

        prompts_input_ids = [torch.tensor(ids, dtype=torch.long) for ids in prompts_input_ids]
        prompt_attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in prompt_attention_mask]
        prompts_input_ids = pad(prompts_input_ids, padding_side="left", padding_value=self.tokenizer.pad_token_id)
        prompt_attention_mask = pad(prompt_attention_mask, padding_side="left", padding_value=0)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "prompts": prompts_input_ids,
            "prompt_attention_mask": prompt_attention_mask,
        }
|
|
|
|
|
def truncate_right(
    input_ids: torch.Tensor, stop_token_id: int, pad_token_id: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Replace everything after the first stop token of each row with padding.

    The stop token itself is kept. Rows without a stop token are returned unchanged, because
    `first_true_indices` then yields the row length and no position lies beyond it.

    Args:
        input_ids (`torch.Tensor`):
            Responses to truncate; the sequence axis is dimension 1.
        stop_token_id (`int`):
            Token id that marks the end of a response.
        pad_token_id (`int`):
            Token id written over the positions that follow the stop token.

    Returns:
        tuple:
            - `output_ids` (`torch.Tensor`):
                Copy of `input_ids` with `pad_token_id` after the first stop token.
            - `mask` (`torch.Tensor`):
                Tensor shaped like `input_ids` holding 1 on kept positions and 0 on the padded-out ones.
    """
    seq_len = input_ids.shape[1]
    stop_positions = first_true_indices(input_ids == stop_token_id).unsqueeze(-1)

    # Position indices shaped so they broadcast against the per-row stop positions.
    view_shape = [1] * (input_ids.dim() - 1) + [seq_len]
    positions = torch.arange(seq_len, device=input_ids.device).view(*view_shape)

    after_stop = positions > stop_positions
    output_ids = input_ids.masked_fill(after_stop, pad_token_id)
    mask = torch.ones_like(input_ids).masked_fill(after_stop, 0)
    return output_ids, mask
|
|
|
|
|
| def add_bos_token_if_needed(
|
| bos_token_id: int | None,
|
| prompt_len_input_ids: int,
|
| prompt_tokens: dict[str, list[int]],
|
| chosen_prompt_len_input_ids: int,
|
| chosen_tokens: dict[str, list[int]],
|
| rejected_prompt_len_input_ids: int,
|
| rejected_tokens: dict[str, list[int]],
|
| ):
|
| if bos_token_id is not None:
|
| if prompt_len_input_ids == 0 or bos_token_id != prompt_tokens["prompt_input_ids"][0]:
|
| prompt_tokens["prompt_input_ids"] = [bos_token_id] + prompt_tokens["prompt_input_ids"]
|
| prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
|
| if chosen_prompt_len_input_ids == 0 or bos_token_id != chosen_tokens["prompt_input_ids"][0]:
|
| chosen_tokens["prompt_input_ids"] = [bos_token_id] + chosen_tokens["prompt_input_ids"]
|
| chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
|
| if rejected_prompt_len_input_ids == 0 or bos_token_id != rejected_tokens["prompt_input_ids"][0]:
|
| rejected_tokens["prompt_input_ids"] = [bos_token_id] + rejected_tokens["prompt_input_ids"]
|
| rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
|
| return prompt_tokens, chosen_tokens, rejected_tokens
|
|
|
|
|
def add_eos_token_if_needed(
    eos_token_id: int, chosen_tokens: dict[str, list[int]], rejected_tokens: dict[str, list[int]]
):
    """
    Append the EOS token to each completion that does not already end with it.

    For both dicts, when `"input_ids"` is empty or its last id differs from `eos_token_id`, the EOS id is appended
    to `"input_ids"` and a `1` to `"attention_mask"`. The lists are modified in place.

    Returns:
        tuple: `(chosen_tokens, rejected_tokens)`, the same dict objects that were passed in.
    """
    for tokens in (chosen_tokens, rejected_tokens):
        ids = tokens["input_ids"]
        if len(ids) == 0 or eos_token_id != ids[-1]:
            ids.append(eos_token_id)
            tokens["attention_mask"].append(1)
    return chosen_tokens, rejected_tokens
|
|
|
|
|
def first_true_indices(bools: torch.Tensor, dtype=torch.long) -> torch.Tensor:
    """
    Return, for each row along the last dimension, the position of its first `True`.

    A row that contains no `True` yields the row length (`bools.size(-1)`).

    Args:
        bools (`torch.Tensor`):
            An N-dimensional boolean tensor.
        dtype (`torch.dtype`, optional):
            Data type of the returned indices. Defaults to `torch.long`.

    Returns:
        `torch.Tensor`:
            An (N-1)-dimensional tensor with the index of the first `True` of each row, or the row length when the
            row has none.
    """
    row_len = bools.size(-1)
    positions = torch.arange(row_len, dtype=dtype, device=bools.device)
    # False entries are shifted up by the row length, so the row minimum is the first True position,
    # or exactly `row_len` when the row has no True at all.
    false_offset = row_len * (~bools).type(dtype)
    return (false_offset + positions).min(dim=-1).values
|
|
|
|
|
def get_reward(
    model: torch.nn.Module, query_responses: torch.Tensor, pad_token_id: int, context_length: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Score query/response sequences with a reward model.

    The backbone (the attribute of `model` named by `model.base_model_prefix`) is run on the sequences and
    `model.score` is applied to its last hidden states. The final reward of each row is the score read at the
    position just before the first pad token found after `context_length`.

    Args:
        model (`torch.nn.Module`):
            Reward model exposing `base_model_prefix` and a `score` head.
        query_responses (`torch.Tensor`):
            Token ids of the concatenated queries and responses.
        pad_token_id (`int`):
            Token id used for padding.
        context_length (`int`):
            Number of leading (query) positions skipped when looking for the end of the response.

    Returns:
        tuple:
            - `reward_logits` (`torch.Tensor`):
                Output of `model.score` for every position.
            - `final_rewards` (`torch.Tensor`):
                Score selected at `sequence_lengths` for each row.
            - `sequence_lengths` (`torch.Tensor`):
                Per-row index of the position preceding the first pad token of the response.
    """
    attention_mask = query_responses != pad_token_id
    # Positions count only non-pad tokens, starting at 0.
    position_ids = attention_mask.cumsum(1) - attention_mask.long()
    backbone = getattr(model, model.base_model_prefix)
    # Pad positions are overwritten with token id 0 before the forward pass (they are masked out anyway).
    safe_input_ids = query_responses.masked_fill(~attention_mask, 0)

    backbone_output = backbone(
        input_ids=safe_input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        return_dict=True,
        output_hidden_states=True,
        use_cache=False,
    )
    reward_logits = model.score(backbone_output.hidden_states[-1])

    first_response_pad = first_true_indices(query_responses[:, context_length:] == pad_token_id)
    sequence_lengths = first_response_pad - 1 + context_length

    batch_index = torch.arange(reward_logits.size(0), device=reward_logits.device)
    final_rewards = reward_logits[batch_index, sequence_lengths].squeeze(-1)
    return reward_logits, final_rewards, sequence_lengths
|
|
|
|
|
def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, gradient_checkpointing_kwargs=None):
    r"""
    Prepare a k-bit quantized transformers model for training (PEFT/QLoRA).

    Every parameter of `model` is frozen. If the model is detected as quantized and `use_gradient_checkpointing` is
    truthy, the input embeddings are made to require gradients and gradient checkpointing is enabled.

    Args:
        model:
            Model to prepare; it is modified in place.
        use_gradient_checkpointing (`bool`, defaults to `True`):
            Whether to enable gradient checkpointing on quantized models.
        gradient_checkpointing_kwargs (`dict`, *optional*):
            Forwarded to `model.gradient_checkpointing_enable` when that method accepts a
            `gradient_checkpointing_kwargs` parameter. `None` is treated as `{}`.

    Returns:
        The same `model` object.
    """
    loaded_in_kbit = any(getattr(model, flag, False) for flag in ("is_loaded_in_8bit", "is_loaded_in_4bit"))
    quant_methods = ["gptq", "aqlm", "eetq", "torchao", "hqq"]
    is_quantized = getattr(model, "quantization_method", None) in quant_methods or getattr(
        model, "hqq_quantized", False
    )

    if gradient_checkpointing_kwargs is None:
        gradient_checkpointing_kwargs = {}

    # Freeze the whole model.
    for _name, param in model.named_parameters():
        param.requires_grad = False

    if not ((loaded_in_kbit or is_quantized) and use_gradient_checkpointing):
        return model

    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:
        # Fallback for models without the helper: mark the embedding output as requiring grad.
        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)

        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    # Only pass the kwargs when `gradient_checkpointing_enable` declares that parameter.
    enable_signature = inspect.signature(model.gradient_checkpointing_enable)
    if "gradient_checkpointing_kwargs" in enable_signature.parameters:
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
    else:
        model.gradient_checkpointing_enable()

    return model
|
|
|
|
|
def enable_gradient_checkpointing(
    model: PreTrainedModel, gradient_checkpointing_kwargs: dict | None
) -> PreTrainedModel:
    """
    Enables gradient checkpointing for the model.

    For a PEFT model the call is made on `model.base_model`, otherwise on `model` itself.
    `gradient_checkpointing_kwargs` is only inspected for its `"use_reentrant"` entry (treated as true when
    missing); it is not forwarded to `gradient_checkpointing_enable`. With reentrant checkpointing the input
    embeddings are additionally made to require gradients.

    Returns:
        The same `model` object.
    """
    checkpoint_target = model.base_model if is_peft_model(model) else model
    checkpoint_target.gradient_checkpointing_enable()

    kwargs = gradient_checkpointing_kwargs or {}
    if not kwargs.get("use_reentrant", True):
        return model

    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:
        # Fallback for models without the helper: mark the embedding output as requiring grad.
        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)

        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    return model
|
|
|
|
|
def prepare_peft_model(
    model: PreTrainedModel, peft_config: "PeftConfig | None", args: TrainingArguments
) -> PreTrainedModel:
    """
    Prepares a model for PEFT training.

    Steps, in order: validate the inputs, prepare a quantized (4-bit/8-bit) model for k-bit training or enable
    gradient checkpointing, wrap the model with `get_peft_model` when a `peft_config` is given, and cast selected
    modules for bf16 training of non-sharded 4-bit models.

    Note that `args.gradient_checkpointing` is set to `False` on the passed `args` object when the model goes
    through `prepare_model_for_kbit_training`, since checkpointing is handled there.

    Raises:
        ImportError: If PEFT is not installed.
        ValueError: If `model` is already a `PeftModel` and a `peft_config` is also provided.
    """
    if not is_peft_available():
        raise ImportError("PEFT is required to use a peft model. Run `pip install peft`.")

    if isinstance(model, PeftModel) and peft_config is not None:
        raise ValueError(
            "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge and "
            "unload the existing adapter, save the resulting base model, and then pass that base model along with the "
            "new `peft_config` to the trainer."
        )

    loaded_in_4bit = getattr(model, "is_loaded_in_4bit", False)
    is_qlora = loaded_in_4bit or getattr(model, "is_loaded_in_8bit", False)

    # A 4-bit model whose first `Params4bit` parameter lives on cpu/meta is treated as sharded.
    is_sharded_qlora = False
    if loaded_in_4bit:
        first_4bit_param = next(
            (param for _, param in model.named_parameters() if param.__class__.__name__ == "Params4bit"),
            None,
        )
        if first_4bit_param is not None:
            is_sharded_qlora = first_4bit_param.data.device.type in {"cpu", "meta"}

    if is_qlora and not is_sharded_qlora and not isinstance(model, PeftModel):
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=args.gradient_checkpointing,
            gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs or {},
        )
        # Checkpointing was handled by `prepare_model_for_kbit_training`; switch it off on `args`.
        args.gradient_checkpointing = False
    elif args.gradient_checkpointing:
        model = enable_gradient_checkpointing(model, args.gradient_checkpointing_kwargs)

    if peft_config is not None:
        peft_kwargs = {}
        if (
            Version(peft.__version__) >= Version("0.12")
            and getattr(model, "is_loaded_in_4bit", False)
            and is_sharded_qlora
        ):
            # Sharded 4-bit models on peft >= 0.12 are wrapped without adapter dtype autocasting.
            peft_kwargs["autocast_adapter_dtype"] = False
        model = get_peft_model(model, peft_config, **peft_kwargs)

    if args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora:
        peft_module_casting_to_bf16(model)

    return model
|
|
|
|
|
| def pad_to_length(tensor: torch.Tensor, length: int, pad_value: int | float, dim: int = -1) -> torch.Tensor:
|
| if tensor.size(dim) >= length:
|
| return tensor
|
| else:
|
| pad_size = list(tensor.shape)
|
| pad_size[dim] = length - tensor.size(dim)
|
| return torch.cat(
|
| [
|
| tensor,
|
| pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device),
|
| ],
|
| dim=dim,
|
| )
|
|
|
|
|
def empty_cache() -> None:
    """Empties the cache of the available torch device.

    The XPU, MLU and NPU backends are probed in that order and the cache of the first available one is emptied.

    If none of them is available, the CUDA cache is emptied instead.
    """
    accelerators = (
        (is_torch_xpu_available, "xpu"),
        (is_torch_mlu_available, "mlu"),
        (is_torch_npu_available, "npu"),
    )
    for is_available, backend_name in accelerators:
        if is_available():
            # The backend attribute is only looked up once its availability check has passed.
            getattr(torch, backend_name).empty_cache()
            return
    torch.cuda.empty_cache()
|
|
|
|
|
def peft_module_casting_to_bf16(model):
    """
    Adjust module dtypes in place for bf16 training.

    Modules that are `LayerNorm` instances, or whose name contains `"norm"`, are cast to float32. Otherwise, modules
    whose name contains `"lm_head"`, `"embed_tokens"`, `"wte"` or `"wpe"` are cast to bfloat16, but only when they
    have a `weight` that is currently float32. Returns `None`.
    """
    bf16_name_markers = ("lm_head", "embed_tokens", "wte", "wpe")
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.LayerNorm) or "norm" in name:
            module.to(torch.float32)
            continue
        if not any(marker in name for marker in bf16_name_markers):
            continue
        if hasattr(module, "weight") and module.weight.dtype == torch.float32:
            module.to(torch.bfloat16)
|
|
|
|
|
# Parameter-name templates tried in order by `create_reference_model` when no explicit `pattern` is given.
# `{layer}` is filled with `num_shared_layers`, and the first template found as a substring of any parameter
# name is used.
LAYER_PATTERNS = [
    "transformer.h.{layer}",
    "model.decoder.layers.{layer}",
    "gpt_neox.layers.{layer}",
    "model.layers.{layer}",
]
|
|
|
|
|
def create_reference_model(
    model: nn.Module, num_shared_layers: int | None = None, pattern: str | None = None
) -> nn.Module:
    """
    Creates a static reference copy of a model. Note that model will be in `.eval()` mode.

    The reference model is a deep copy of `model` with every parameter frozen. When `num_shared_layers` is given,
    the parameters of `model` that come before the first parameter whose name matches the layer pattern are frozen
    in `model` as well, so those layers stay identical in both models. The copies are independent tensors; no
    memory is shared between the two models.

    Args:
        model ([`nn.Module`]): The model to be copied.
        num_shared_layers (`int`, *optional*):
            The number of initial layers that are kept frozen in both models. When `None`, `model` is left
            untouched.
        pattern (`str`, *optional*): The shared layers are selected with a string pattern
            (e.g. "transformer.h.{layer}" for GPT2) and if a custom pattern is necessary it can be passed here.

    Returns:
        [`nn.Module`]

    Raises:
        ValueError: If DeepSpeed ZeRO-3 is enabled, or if no pattern is given and none of `LAYER_PATTERNS` matches
            a parameter name.
    """
    if is_deepspeed_zero3_enabled():
        raise ValueError(
            "DeepSpeed ZeRO-3 is enabled and is not compatible with `create_reference_model()`. Please instantiate your reference model directly with `AutoModelForCausalLM.from_pretrained()`."
        )

    parameter_names = [n for n, _ in model.named_parameters()]
    ref_model = deepcopy(model)

    # The reference model is never trained: freeze all of its parameters up front, including the copies of the
    # layers that stay frozen in `model`.
    for param_name in parameter_names:
        ref_model.get_parameter(param_name).requires_grad = False

    # No shared layers requested: `model` is left as it is.
    if num_shared_layers is None:
        return ref_model.eval()

    # Resolve the layer-name pattern that marks the first non-shared layer.
    if pattern is not None:
        pattern = pattern.format(layer=num_shared_layers)
    else:
        for pattern_candidate in LAYER_PATTERNS:
            pattern_candidate = pattern_candidate.format(layer=num_shared_layers)
            if any(pattern_candidate in name for name in parameter_names):
                pattern = pattern_candidate
                break

    if pattern is None:
        raise ValueError("Layer pattern could not be matched.")

    # Parameters listed before the first match belong to the shared layers; freeze them in `model` too.
    pattern_matched = False
    for name in parameter_names:
        if pattern in name:
            pattern_matched = True
            break
        model.get_parameter(name).requires_grad = False

    if not pattern_matched:
        logging.warning("Pattern passed or found, but no layers matched in the model. Check for a typo.")

    return ref_model.eval()
|
|
|
|
|
def truncate_dataset(
    dataset: DatasetType,
    max_length: int,
    map_kwargs: dict[str, Any] | None = None,
) -> DatasetType:
    r"""
    Truncate sequences in a dataset to a specified `max_length`.

    Every list-typed (or large-list-typed) column is sliced to its first `max_length` items; other columns are left
    unchanged. The dataset is processed in Arrow format and its original format is restored afterwards.

    Args:
        dataset ([`~datasets.Dataset`] or [`~datasets.DatasetDict`]):
            Dataset to truncate.
        max_length (`int`):
            Maximum sequence length to truncate to.
        map_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the dataset's map method when truncating examples.

    Returns:
        [`~datasets.Dataset`] or [`~datasets.DatasetDict`]: The dataset with truncated sequences.

    Example:
    ```python
    >>> from datasets import Dataset

    >>> examples = {
    ...     "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
    ...     "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
    ... }
    >>> dataset = Dataset.from_dict(examples)
    >>> truncated_dataset = truncate_dataset(dataset, max_length=2)
    >>> truncated_dataset[:]
    {'input_ids': [[1, 2], [4, 5], [8]],
     'attention_mask': [[0, 1], [0, 0], [1]]}
    ```
    """
    extra_map_kwargs = {} if map_kwargs is None else map_kwargs

    def _is_list_column(column) -> bool:
        return pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type)

    def _truncate_list_columns(examples):
        # `examples` is an Arrow table because the dataset is switched to the "arrow" format below.
        columns = [
            pc.list_slice(column, 0, max_length) if _is_list_column(column) else column
            for column in examples.columns
        ]
        return pa.Table.from_arrays(columns, names=examples.column_names)

    # Remember the caller's format so it can be restored once the Arrow-level map is done.
    original_format = _get_dataset_format(dataset)
    arrow_dataset = dataset.with_format("arrow")
    truncated = arrow_dataset.map(_truncate_list_columns, batched=True, **extra_map_kwargs)
    return truncated.with_format(**original_format)
|
|
|