| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import warnings |
| | from typing import List, Literal |
| |
|
| | import torch |
| |
|
| |
|
def reshape_weight_task_tensors(task_tensors, weights):
    """
    Broadcast-align `weights` with `task_tensors` by appending singleton dimensions.

    Args:
        task_tensors (`torch.Tensor`): Reference tensor whose rank `weights` is expanded to match.
        weights (`torch.Tensor`): The tensor to be reshaped.

    Returns:
        `torch.Tensor`: `weights` viewed with trailing 1-sized dimensions so it broadcasts
        against `task_tensors`.
    """
    # Number of trailing singleton axes needed to reach the rank of task_tensors.
    missing_dims = task_tensors.dim() - weights.dim()
    return weights.view(*weights.shape, *([1] * missing_dims))
| |
|
| |
|
def magnitude_based_pruning(tensor: torch.Tensor, density: float) -> torch.Tensor:
    """
    Zero out the smallest-magnitude entries, retaining only the top-k values by absolute
    value, where k is the fraction `density` of the element count.

    Args:
        tensor (`torch.Tensor`): The tensor to prune.
        density (`float`): The fraction of values to preserve. Should be in [0,1].

    Returns:
        `torch.Tensor`: The tensor with the pruned weights.
    """
    num_keep = int(density * tensor.numel())
    # Indices of the largest-magnitude entries in the flattened tensor.
    _, keep_indices = torch.topk(tensor.abs().reshape(-1), k=num_keep, largest=True)
    keep_mask = torch.zeros_like(tensor).reshape(-1)
    keep_mask[keep_indices] = 1
    return tensor * keep_mask.reshape(tensor.shape)
| |
|
| |
|
def random_pruning(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor:
    """
    Prune random values based on the specified fraction `density`.

    Args:
        tensor (`torch.Tensor`):The tensor to prune.
        density (`float`):The fraction of values to preserve. Should be in [0,1].
        rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor.

    Returns:
        `torch.Tensor`: The pruned tensor.
    """
    # Each element is kept independently with probability `density`.
    mask = torch.bernoulli(torch.full_like(input=tensor, fill_value=density))
    pruned_tensor = tensor * mask
    if rescale:
        # Divide surviving values by the keep probability so the expected value matches
        # the original tensor (inverted-dropout rescaling).
        # Bug fix: the previous code called torch.div(...) but discarded its return
        # value (torch.div is not in-place), so `rescale=True` had no effect.
        pruned_tensor = torch.div(input=pruned_tensor, other=density)
    return pruned_tensor
| |
|
| |
|
def prune(
    tensor: torch.Tensor, density: float, method: Literal["magnitude", "random"], rescale: bool = False
) -> torch.Tensor:
    """
    Prune the values of task tensors based on the `method`.

    Args:
        tensor (`torch.Tensor`): The tensor to prune.
        density (`float`): The fraction of values to preserve. Should be in [0,1].
        method (`str`): The method to use to prune. Should be one of ["magnitude", "random"].
        rescale (`bool`): Whether to rescale the result to preserve the expected value of the original tensor.

    Returns:
        `torch.Tensor`: The pruned tensor.
    """
    # Guard clauses: negative density is invalid; density >= 1 keeps everything.
    if density < 0:
        raise ValueError(f"Density should be >= 0, got {density}")
    if density >= 1:
        warnings.warn(f"The density {density} is greater than or equal to 1, no pruning will be performed.")
        return tensor

    if method == "magnitude":
        return magnitude_based_pruning(tensor, density)
    if method == "random":
        return random_pruning(tensor, density, rescale=rescale)
    raise ValueError(f"Unknown method {method}")
| |
|
| |
|
def calculate_majority_sign_mask(
    tensor: torch.Tensor, method: Literal["total", "frequency"] = "total"
) -> torch.Tensor:
    """
    Get the mask of the majority sign across the task tensors. Task tensors are stacked on dimension 0.

    Args:
        tensor (`torch.Tensor`): The tensor to get the mask from.
        method (`str`): The method to use to get the mask. Should be one of ["total", "frequency"].

    Returns:
        `torch.Tensor`: Boolean mask, True where an element's sign agrees with the majority sign.
    """
    signs = tensor.sign()
    if method == "total":
        # Majority decided by the signed sum of values (magnitude-weighted vote).
        vote = tensor.sum(dim=0)
    elif method == "frequency":
        # Majority decided by counting signs only (one vote per tensor).
        vote = signs.sum(dim=0)
    else:
        raise RuntimeError(f'Unimplemented mask method "{method}"')
    # Ties (vote == 0) resolve to the positive sign.
    majority = torch.where(vote >= 0, 1, -1)
    return signs == majority
| |
|
| |
|
def disjoint_merge(task_tensors: torch.Tensor, majority_sign_mask: torch.Tensor) -> torch.Tensor:
    """
    Merge the task tensors using disjoint merge.

    Args:
        task_tensors (`torch.Tensor`): The task tensors to merge, stacked on dimension 0.
        majority_sign_mask (`torch.Tensor`): The mask of the majority sign across the task tensors.

    Returns:
        `torch.Tensor`: The merged tensor (mean over sign-agreeing contributors).
    """
    masked_sum = (task_tensors * majority_sign_mask).sum(dim=0)
    contributor_count = majority_sign_mask.sum(dim=0)
    # Clamp avoids division by zero where no tensor matched the majority sign.
    return masked_sum / contributor_count.clamp(min=1.0)
| |
|
| |
|
def task_arithmetic(task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor:
    """
    Merge the task tensors using `task arithmetic`.

    Args:
        task_tensors (`List[torch.Tensor]`): The task tensors to merge.
        weights (`torch.Tensor`): The weights of the task tensors.

    Returns:
        `torch.Tensor`: The merged tensor (weighted sum over the task dimension).
    """
    stacked = torch.stack(task_tensors, dim=0)
    # Align the weight vector so it broadcasts across the stacked task dimension.
    aligned_weights = reshape_weight_task_tensors(stacked, weights)
    return (stacked * aligned_weights).sum(dim=0)
| |
|
| |
|
def magnitude_prune(task_tensors: List[torch.Tensor], weights: torch.Tensor, density: float) -> torch.Tensor:
    """
    Prune each task tensor by magnitude (keeping the top `density` fraction of values),
    then merge them with a weighted sum (task arithmetic on the pruned tensors).

    Note: the docstring previously said "task arithmetic" only — it omitted the
    magnitude-pruning step this function actually performs.

    Args:
        task_tensors(`List[torch.Tensor]`):The task tensors to merge.
        weights (`torch.Tensor`):The weights of the task tensors.
        density (`float`): The fraction of values to preserve. Should be in [0,1].

    Returns:
        `torch.Tensor`: The merged tensor.
    """

    # sparsify
    task_tensors = [prune(tensor, density, method="magnitude") for tensor in task_tensors]
    task_tensors = torch.stack(task_tensors, dim=0)

    # weighted task tensors
    weights = reshape_weight_task_tensors(task_tensors, weights)
    weighted_task_tensors = task_tensors * weights
    mixed_task_tensors = weighted_task_tensors.sum(dim=0)
    return mixed_task_tensors
| |
|
| |
|
def ties(
    task_tensors: List[torch.Tensor],
    weights: torch.Tensor,
    density: float,
    majority_sign_method: Literal["total", "frequency"] = "total",
) -> torch.Tensor:
    """
    Merge the task tensors using `ties`: magnitude-prune each tensor, resolve sign
    conflicts via a majority vote, then average the sign-agreeing weighted values.

    Args:
        task_tensors (`List[torch.Tensor]`): The task tensors to merge.
        weights (`torch.Tensor`): The weights of the task tensors.
        density (`float`): The fraction of values to preserve. Should be in [0,1].
        majority_sign_method (`str`):
            The method to use to get the majority sign mask. Should be one of ["total", "frequency"].

    Returns:
        `torch.Tensor`: The merged tensor.
    """
    # Trim: keep only the largest-magnitude entries of each task tensor.
    pruned = [prune(t, density, method="magnitude") for t in task_tensors]
    stacked = torch.stack(pruned, dim=0)

    # Elect: per-element majority sign across tasks.
    sign_mask = calculate_majority_sign_mask(stacked, method=majority_sign_method)

    # Merge: weighted, restricted to sign-agreeing contributors.
    aligned_weights = reshape_weight_task_tensors(stacked, weights)
    return disjoint_merge(stacked * aligned_weights, sign_mask)
| |
|
| |
|
def dare_linear(task_tensors: List[torch.Tensor], weights: torch.Tensor, density: float) -> torch.Tensor:
    """
    Merge the task tensors using `dare linear`: randomly drop values (with rescaling to
    preserve the expected value), then take a weighted sum.

    Args:
        task_tensors (`List[torch.Tensor]`): The task tensors to merge.
        weights (`torch.Tensor`): The weights of the task tensors.
        density (`float`): The fraction of values to preserve. Should be in [0,1].

    Returns:
        `torch.Tensor`: The merged tensor.
    """
    # DARE: drop-and-rescale each task tensor independently.
    pruned = [prune(t, density, method="random", rescale=True) for t in task_tensors]
    stacked = torch.stack(pruned, dim=0)

    aligned_weights = reshape_weight_task_tensors(stacked, weights)
    return (stacked * aligned_weights).sum(dim=0)
| |
|
| |
|
def dare_ties(
    task_tensors: List[torch.Tensor],
    weights: torch.Tensor,
    density: float,
    majority_sign_method: Literal["total", "frequency"] = "total",
) -> torch.Tensor:
    """
    Merge the task tensors using `dare ties`: randomly drop values (with rescaling),
    then apply the TIES sign-election and disjoint merge to the weighted tensors.

    Args:
        task_tensors (`List[torch.Tensor]`): The task tensors to merge.
        weights (`torch.Tensor`): The weights of the task tensors.
        density (`float`): The fraction of values to preserve. Should be in [0,1].
        majority_sign_method (`str`):
            The method to use to get the majority sign mask. Should be one of ["total", "frequency"].

    Returns:
        `torch.Tensor`: The merged tensor.
    """
    # DARE: drop-and-rescale each task tensor independently.
    pruned = [prune(t, density, method="random", rescale=True) for t in task_tensors]
    stacked = torch.stack(pruned, dim=0)

    # TIES: per-element majority sign across tasks.
    sign_mask = calculate_majority_sign_mask(stacked, method=majority_sign_method)

    aligned_weights = reshape_weight_task_tensors(stacked, weights)
    return disjoint_merge(stacked * aligned_weights, sign_mask)
| |
|