# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: LGPL-3.0-only
# della + live audit report by Naphula

import logging
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

import torch
from pydantic import BaseModel
from typing_extensions import Literal, override

from mergekit.architecture import WeightInfo
from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.merge_methods.base import (
    ConfigParameterDef,
    MergeMethod,
    MergeTensorInput,
)
from mergekit.sparsify import RescaleNorm, SparsificationMethod, sparsify


class ConsensusMethod(str, Enum):
    """Sign-consensus strategy used to resolve interference between task vectors."""

    count = "count"
    sum = "sum"


class GeneralizedTaskArithmeticMerge(MergeMethod, BaseModel, frozen=True):
    """Configurable task-arithmetic merge (TIES / DARE / DELLA family).

    The concrete algorithm is selected by the combination of
    ``consensus_method`` and ``sparsification_method`` plus the per-tensor
    parameters exposed via :meth:`tensor_parameters`.
    """

    consensus_method: Optional[ConsensusMethod]
    sparsification_method: Optional[SparsificationMethod]
    default_normalize: bool
    default_rescale: bool
    method_name: str
    method_pretty_name: Optional[str]
    method_reference_url: Optional[str]

    def name(self) -> str:
        return self.method_name

    @override
    def pretty_name(self) -> Optional[str]:
        return self.method_pretty_name

    @override
    def reference_url(self) -> Optional[str]:
        return self.method_reference_url

    def parameters(self) -> List[ConfigParameterDef]:
        """Global (non-per-tensor) configuration parameters for this method."""
        return [
            ConfigParameterDef(name="int8_mask", required=False, default_value=False),
            ConfigParameterDef(
                name="normalize", required=False, default_value=self.default_normalize
            ),
            ConfigParameterDef(
                name="rescale", required=False, default_value=self.default_rescale
            ),
            # Global scaling factor applied to the mixed delta before adding
            # it back onto the base tensor.
            ConfigParameterDef(name="lambda", required=False, default_value=1.0),
        ]

    def tensor_parameters(self) -> List[ConfigParameterDef]:
        """Per-model, per-tensor parameters; extras depend on the sparsifier."""
        res = [
            ConfigParameterDef(name="weight", required=True),
            ConfigParameterDef(name="density", required=False, default_value=1.0),
        ]
        if self.sparsification_method == SparsificationMethod.magnitude_outliers:
            # gamma is only consumed by the magnitude_outliers sparsifier.
            res.append(ConfigParameterDef(name="gamma", default_value=0.01))
        if self.sparsification_method == SparsificationMethod.della_magprune:
            # epsilon is only consumed by DELLA's magnitude-based pruning.
            res.append(ConfigParameterDef(name="epsilon", default_value=0.15))
        return res

    def make_task(
        self,
        output_weight: WeightInfo,
        tensors: MergeTensorInput,
        base_model: Optional[ModelReference],
        parameters: ImmutableMap[str, Any],
        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
    ) -> Task:
        """Build the graph task that merges one output weight."""
        return GTATask(
            method=self,
            tensors=tensors,
            base_model=base_model,
            tensor_parameters=tensor_parameters,
            int8_mask=parameters["int8_mask"],
            normalize=parameters["normalize"],
            lambda_=parameters["lambda"],
            # rescale=True historically maps to L1-norm rescaling.
            rescale_norm=RescaleNorm.l1 if parameters["rescale"] else None,
            weight_info=output_weight,
        )


class GTATask(Task[torch.Tensor]):
    """Executes generalized task arithmetic for a single weight tensor."""

    method: GeneralizedTaskArithmeticMerge
    tensors: MergeTensorInput
    base_model: ModelReference
    weight_info: WeightInfo
    tensor_parameters: ImmutableMap[ModelReference, Any]
    int8_mask: bool
    normalize: bool
    lambda_: float
    rescale_norm: Optional[RescaleNorm]

    def uses_accelerator(self) -> bool:
        return True

    def arguments(self) -> Dict[str, Task]:
        return {"tensors": self.tensors}

    def execute(
        self,
        tensors: Dict[ModelReference, torch.Tensor],
        **_kwargs,
    ) -> torch.Tensor:
        """Merge the gathered tensors into a single output tensor.

        Steps: build task vectors (deltas vs. the base model), optionally
        sparsify each delta, weight them, optionally apply sign consensus,
        normalize, scale by lambda, and add back onto the base tensor.
        """
        # collect task vectors
        tvs, base = get_task_vectors(
            self.weight_info,
            self.base_model,
            tensors,
            tensor_parameters=self.tensor_parameters.data,
        )
        # Nothing to merge (e.g. every non-base tensor was skipped for a
        # shape mismatch): pass the base tensor through unchanged.
        if not tvs:
            return base

        # --- LIVE AUDIT CHART ---
        log_della_audit(
            self.weight_info.name,
            self.base_model,
            tvs,
            self.lambda_,
            self.method.method_pretty_name,
        )
        # ------------------------

        # sparsify each task vector independently (pre-mixing, as in
        # TIES/DARE/DELLA), forwarding only the kwargs the sparsifier takes.
        if self.method.sparsification_method:
            for tv_info in tvs:
                extra = {k: tv_info[k] for k in ("gamma", "epsilon") if k in tv_info}
                tv_info["delta"] = sparsify(
                    tv_info["delta"],
                    density=tv_info["density"],
                    method=self.method.sparsification_method,
                    rescale_norm=self.rescale_norm,
                    **extra,
                )

        deltas = torch.stack([tv["delta"] for tv in tvs], dim=0)
        weights = torch.tensor(
            [tv["weight"] for tv in tvs], dtype=deltas.dtype, device=deltas.device
        )
        # Broadcast the per-model weight vector across the delta dimensions.
        while len(deltas.shape) > len(weights.shape):
            weights.unsqueeze_(-1)

        weighted_deltas = deltas * weights

        # get sign consensus and mix deltas
        if self.method.consensus_method:
            mask_dtype = torch.int8 if self.int8_mask else base.dtype
            mask = get_mask(
                weighted_deltas,
                method=self.method.consensus_method,
                mask_dtype=mask_dtype,
            )
            mixed_delta = (weighted_deltas * mask).sum(dim=0)
            divisor = (weights * mask).sum(dim=0)
            # Avoid division by zero where no model contributed.
            divisor[divisor == 0] = 1
        else:
            mixed_delta = weighted_deltas.sum(dim=0)
            divisor = weights.sum(dim=0)
            # Signed weights may nearly cancel; treat near-zero sums as 1.
            divisor[divisor.abs() < 1e-8] = 1

        if self.normalize:
            mixed_delta /= divisor

        if self.lambda_ != 1:
            mixed_delta *= self.lambda_

        return (base + mixed_delta).to(base.dtype)

    def group_label(self) -> Optional[str]:
        return self.tensors.group_label()


def get_task_vectors(
    weight_info: WeightInfo,
    base_model: ModelReference,
    tensors: ImmutableMap[ModelReference, torch.Tensor],
    tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
) -> Tuple[List[Dict[str, Any]], torch.Tensor]:
    """Compute per-model deltas relative to the base model's tensor.

    Returns a list of dicts (one per non-base model) carrying the model
    reference, its ``delta`` tensor, and all of its tensor parameters,
    plus the base tensor itself. Models whose tensor shape mismatches the
    base are truncated (embeddings) or skipped with a warning.
    """
    keys = list(tensors.keys())
    base = tensors[base_model]

    parameter_name = weight_info.name

    res = []
    for model in keys:
        if model == base_model:
            continue

        x = tensors[model].to(base.dtype)
        if x.shape != base.shape:
            if weight_info.is_embed:
                # Vocab-size mismatch: keep the overlapping rows/cols only.
                x = x[: base.shape[0], : base.shape[1]]
                logging.warning("Using submatrix of %s:%s", model, parameter_name)
            else:
                logging.warning(
                    "skipping %s:%s due to size mismatch", model, parameter_name
                )
                continue

        delta = x - base
        # Free the source tensors eagerly to keep peak memory down.
        del x
        del tensors[model]

        d = {"model": model, "delta": delta}
        for p in tensor_parameters[model]:
            d[p] = tensor_parameters[model][p]
        res.append(d)

    return res, base


def get_mask(
    delta: torch.Tensor,
    method: Literal["sum", "count"] = "sum",
    mask_dtype: Optional[torch.dtype] = None,
):
    """Returns a mask determining which delta vectors should be merged
    into the final model.

    For the methodology described in the TIES paper use 'sum'. For a
    simpler naive count of signs, use 'count'."""
    if mask_dtype is None:
        mask_dtype = delta.dtype

    sign = delta.sign().to(mask_dtype)

    if method == "sum":
        # Majority sign weighted by magnitude (TIES elect step).
        sign_weight = delta.sum(dim=0)
        majority_sign = (sign_weight >= 0).to(mask_dtype) * 2 - 1
        del sign_weight
    elif method == "count":
        # Unweighted majority vote over signs.
        majority_sign = (sign.sum(dim=0) >= 0).to(mask_dtype) * 2 - 1
    else:
        raise RuntimeError(f'Unimplemented mask method "{method}"')

    return sign == majority_sign


def log_della_audit(
    layer_name: str,
    base_model: ModelReference,
    tvs: List[Dict[str, Any]],
    global_lambda: float,
    method_name: str,
):
    """Prints and saves a bar chart of DELLA/Task Arithmetic distribution
    based on actual Delta Norms."""
    # Take the final path component regardless of separator style.
    base_name = str(base_model.model.path).split("\\")[-1].split("/")[-1][:50]
    bar_char = "█"

    lines = [f"\n[{method_name} Audit] Layer: {layer_name} | Lambda={global_lambda:.2f}"]
    lines.append(f" [BASE] {base_name:<50}")

    # 1. Calculate stats
    stats = []
    total_impact = 0.0
    for tv in tvs:
        model_name = str(tv['model'].model.path).split("\\")[-1].split("/")[-1][:50]
        weight = tv.get('weight', 0.0)
        density = tv.get('density', 1.0)
        epsilon = tv.get('epsilon', None)
        delta = tv.get('delta', None)

        norm = 0.0
        if delta is not None:
            # Use float32 for norm calculation to be safe
            norm = torch.norm(delta.float()).item()

        # Effective contribution magnitude = Weight * Norm
        # This shows how much this model is actually moving the weights
        impact = weight * norm
        total_impact += impact

        stats.append({
            'name': model_name,
            'weight': weight,
            'density': density,
            'epsilon': epsilon,
            'norm': norm,
            'impact': impact
        })

    # Sort by name for consistent logs
    stats.sort(key=lambda x: x['name'])

    # 2. Generate bars
    for s in stats:
        # Calculate percentage relative to the sum of all impacts (Share of Voice)
        pct = (s['impact'] / total_impact * 100) if total_impact > 0 else 0.0

        # Bar length (max 50 chars for 100%)
        bar_len = int(max(0, min(50, pct / 2)))
        bar = bar_char * bar_len

        # Format info string
        # W=Weight, D=Density, N=DeltaNorm
        info = f"W:{s['weight']:.2f} D:{s['density']:.2f} N:{s['norm']:.2f}"
        if s['epsilon'] is not None:
            info += f" E:{s['epsilon']:.2f}"

        lines.append(f" {s['name']:<50}: {bar:<50} {pct:5.1f}% ({info})")

    log_entry = "\n".join(lines)
    print(log_entry)

    with open("della_audit.log", "a", encoding="utf-8") as f:
        f.write(log_entry + "\n")