Upload 129 files

- obliteratus/abliterate.py +88 -25
- tests/test_abliterate.py +13 -3
obliteratus/abliterate.py
CHANGED
```diff
@@ -2713,7 +2713,10 @@ class AbliterationPipeline:
                 if norm_preserve and original_norm > 0:
                     new_norm = W_slice.norm().item()
                     if new_norm > 0:
-                        W_slice.mul_(original_norm / new_norm)
+                        ratio = original_norm / new_norm
+                        if ratio > 1.10:
+                            ratio = 1.10
+                        W_slice.mul_(ratio)
 
             elif W.shape[1] == hidden_dim:
                 # Transposed: W is (attn_dim, hidden_dim), rows by head
```
```diff
@@ -2729,7 +2732,10 @@ class AbliterationPipeline:
                 if norm_preserve and original_norm > 0:
                     new_norm = W_slice.norm().item()
                     if new_norm > 0:
-                        W_slice.mul_(original_norm / new_norm)
+                        ratio = original_norm / new_norm
+                        if ratio > 1.10:
+                            ratio = 1.10
+                        W_slice.mul_(ratio)
 
         if is_quantized:
             AbliterationPipeline._replace_quantized_weight(proj, W)
```
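The two hunks above apply one recurring idiom that repeats at every norm-restoration site in this patch: rescale back toward the pre-projection Frobenius norm, but never amplify by more than 1.10x. A minimal standalone sketch of the idiom (the helper name is hypothetical, not part of the patch's API):

```python
import torch

def restore_norm_capped(W: torch.Tensor, original_norm: float, cap: float = 1.10) -> None:
    """Rescale W in place toward its pre-projection norm, never amplifying
    by more than `cap`. Hypothetical helper mirroring the patch's inline code."""
    new_norm = W.norm().item()
    if new_norm <= 0:
        return  # degenerate weight; leave untouched
    ratio = min(original_norm / new_norm, cap)
    W.mul_(ratio)
```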
```diff
@@ -2913,25 +2919,24 @@ class AbliterationPipeline:
         # ── Guard: compound norm amplification ────────────────────────
         # When true_iterative_refinement is disabled, subsequent passes
         # re-apply the SAME projection directions without re-probing.
-        # With norm_preserve=True
-        #
-        #
-        #
-        #
-        #
-        #
+        # With norm_preserve=True, this creates pathological amplification:
+        # each pass removes some energy, then norm-restoration rescales
+        # the entire weight matrix UP to compensate, amplifying non-refusal
+        # components. With regularization > 0, the partial removal makes
+        # this especially severe (residual refusal is re-projected each
+        # pass), but even regularization=0 causes drift because the second
+        # pass projects from already-rescaled weights, finding phantom
+        # residuals from floating-point imprecision that compound.
         #
-        # Fix: cap to 1 pass when not re-probing + norm-preserving
-        #
-        # amplification in this configuration.
+        # Fix: cap to 1 pass when not re-probing + norm-preserving,
+        # since extra passes without re-extraction are purely destructive.
         effective_passes = self.refinement_passes
         if (effective_passes > 1
                 and not self.true_iterative_refinement
-                and self.norm_preserve
-                and self.regularization > 0):
+                and self.norm_preserve):
             self.log(
                 f"Capping refinement_passes from {effective_passes} to 1: "
-                f"norm_preserve
+                f"norm_preserve without re-probing causes "
                 f"compound amplification (directions are not re-extracted)"
             )
             effective_passes = 1
```
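The mechanism this guard targets is easy to reproduce in isolation: with partial removal (regularization), each pass shrinks the same refusal direction and then uncapped norm restoration scales the whole matrix back up, so everything orthogonal to that direction grows pass after pass. A toy demonstration (assumed shapes and scale; not the pipeline's API):

```python
import torch

torch.manual_seed(0)
W = torch.randn(64, 64)
d = torch.nn.functional.normalize(torch.randn(64, 1), dim=0)

original_norm = W.norm().item()
scale = 0.5  # partial removal, i.e. regularization > 0
for step in range(4):
    W = W - scale * (W @ d) @ d.T              # re-apply the SAME direction
    W = W * (original_norm / W.norm().item())  # uncapped norm restoration
    orth = W - (W @ d) @ d.T                   # component orthogonal to d
    print(step, orth.norm().item())            # grows every pass
```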
```diff
@@ -3355,14 +3360,39 @@ class AbliterationPipeline:
                 break
         if lm_head_name is not None:
             lm_reg = (1.0 - self.reflection_strength) if self.invert_refusal else 0.0
+            # Use bulk norm preservation for lm_head: capture norm
+            # ONCE before all directions, restore ONCE after. Per-
+            # direction rescaling on lm_head is especially destructive
+            # because it directly distorts token logits — amplifying
+            # non-refusal vocabulary embeddings causes degenerate
+            # generation (repeated punctuation / gibberish).
+            lm_head_obj = getattr(model, lm_head_name, None)
+            lm_multi_dir = (
+                subspace_on_device.shape[0] > 1
+                and self.norm_preserve
+                and lm_head_obj is not None
+                and hasattr(lm_head_obj, "weight")
+            )
+            lm_original_norm = 0.0
+            if lm_multi_dir:
+                lm_original_norm = lm_head_obj.weight.data.norm().item()
             for dir_idx in range(subspace_on_device.shape[0]):
                 d = subspace_on_device[dir_idx].unsqueeze(-1)
                 lm_head_count += self._project_out_advanced(
                     model, d, [lm_head_name],
-                    norm_preserve=self.norm_preserve,
+                    norm_preserve=self.norm_preserve and not lm_multi_dir,
                     regularization=lm_reg,
                 )
                 del d
+            # Restore lm_head norm once after all directions
+            if lm_multi_dir and lm_original_norm > 0 and lm_head_obj is not None:
+                new_norm = lm_head_obj.weight.data.norm().item()
+                if new_norm > 0 and not math.isnan(new_norm) and not math.isinf(new_norm):
+                    ratio = lm_original_norm / new_norm
+                    if ratio > 1.10:
+                        ratio = 1.10
+                    if abs(ratio - 1.0) > 1e-6:
+                        lm_head_obj.weight.data.mul_(ratio)
         del subspace_on_device
         if lm_head_count > 0:
             total_modified += lm_head_count
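The hunk above swaps per-direction restoration for a capture-once/restore-once scheme on the head. A runnable sketch of the same scheme on a toy head (names, sizes, and the projection loop are illustrative, not the pipeline's code):

```python
import torch

torch.manual_seed(0)
lm_head = torch.nn.Linear(16, 100, bias=False)  # toy stand-in for the real head
directions = [torch.nn.functional.normalize(torch.randn(16), dim=0) for _ in range(4)]

original = lm_head.weight.data.norm().item()     # capture ONCE
for d in directions:                             # project all directions, no rescaling
    coeffs = lm_head.weight.data @ d             # (100,) per-row components along d
    lm_head.weight.data -= torch.outer(coeffs, d)
new = lm_head.weight.data.norm().item()
if new > 0:
    lm_head.weight.data.mul_(min(original / new, 1.10))  # restore ONCE, capped
```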
```diff
@@ -4042,7 +4072,12 @@ class AbliterationPipeline:
             if math.isnan(new_norm) or math.isinf(new_norm) or new_norm == 0:
                 continue  # Skip — weight is degenerate after projection
             if abs(new_norm - original_norm) > 1e-6:
-                param.data.mul_(original_norm / new_norm)
+                ratio = original_norm / new_norm
+                # Cap amplification to prevent compound norm drift across
+                # layers. Uncapped amplification destroys coherence.
+                if ratio > 1.10:
+                    ratio = 1.10
+                param.data.mul_(ratio)
 
     @staticmethod
     def _project_out_advanced(
```
```diff
@@ -4099,7 +4134,14 @@ class AbliterationPipeline:
             new_norm_sq = max(0.0, original_norm_sq - scale * (2 - scale) * coeff_norm_sq)
             if new_norm_sq > 0:
                 import math
-                W.mul_(math.sqrt(original_norm_sq / new_norm_sq))
+                ratio = math.sqrt(original_norm_sq / new_norm_sq)
+                # Cap amplification: uncapped rescaling compounds
+                # across layers and directions, destroying coherence.
+                # 1.10 keeps per-projection drift bounded while
+                # allowing legitimate norm preservation.
+                if ratio > 1.10:
+                    ratio = 1.10
+                W.mul_(ratio)
 
         if is_quantized:
             AbliterationPipeline._replace_quantized_weight(proj, W)
```
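The `new_norm_sq` expression here is the analytic Frobenius norm after a partial rank-1 projection: for a unit direction d and W' = W - scale * (W d) d^T, the identity ||W'||_F^2 = ||W||_F^2 - scale * (2 - scale) * ||W d||^2 holds, with coeff_norm_sq = ||W d||^2. A quick numeric self-check of that identity (toy sizes, assumed projection convention):

```python
import math
import torch

torch.manual_seed(0)
W = torch.randn(32, 48)
d = torch.nn.functional.normalize(torch.randn(48, 1), dim=0)
scale = 0.7  # partial projection strength

original_norm_sq = W.norm().pow(2).item()
coeff = W @ d                                   # (32, 1) components along d
coeff_norm_sq = coeff.pow(2).sum().item()

W_proj = W - scale * coeff @ d.T                # remove `scale` of the direction
predicted = max(0.0, original_norm_sq - scale * (2 - scale) * coeff_norm_sq)
assert math.isclose(predicted, W_proj.norm().pow(2).item(), rel_tol=1e-4)
```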
```diff
@@ -4124,7 +4166,10 @@ class AbliterationPipeline:
             new_norm_sq = max(0.0, original_norm_sq - scale * (2 - scale) * coeff_norm_sq)
             if new_norm_sq > 0:
                 import math
-                W.mul_(math.sqrt(original_norm_sq / new_norm_sq))
+                ratio = math.sqrt(original_norm_sq / new_norm_sq)
+                if ratio > 1.10:
+                    ratio = 1.10
+                W.mul_(ratio)
 
         if is_quantized:
             AbliterationPipeline._replace_quantized_weight(proj, W)
```
```diff
@@ -4227,7 +4272,10 @@ class AbliterationPipeline:
             if norm_preserve and original_norm > 0:
                 new_norm = W.norm().item()
                 if new_norm > 0:
-                    W.mul_(original_norm / new_norm)
+                    ratio = original_norm / new_norm
+                    if ratio > 1.10:
+                        ratio = 1.10
+                    W.mul_(ratio)
             count += 1
         elif W.shape[0] == d.shape[0]:
             original_norm = W.norm().item() if norm_preserve else 0.0
```
```diff
@@ -4237,7 +4285,10 @@ class AbliterationPipeline:
             if norm_preserve and original_norm > 0:
                 new_norm = W.norm().item()
                 if new_norm > 0:
-                    W.mul_(original_norm / new_norm)
+                    ratio = original_norm / new_norm
+                    if ratio > 1.10:
+                        ratio = 1.10
+                    W.mul_(ratio)
             count += 1
 
         if count > 0:
```
```diff
@@ -4809,7 +4860,10 @@ class AbliterationPipeline:
             if norm_preserve and original_norm > 0:
                 new_norm = W.norm().item()
                 if new_norm > 0:
-                    W.mul_(original_norm / new_norm)
+                    ratio = original_norm / new_norm
+                    if ratio > 1.10:
+                        ratio = 1.10
+                    W.mul_(ratio)
             count += 1
         elif W.shape[0] == d.shape[0]:
             original_norm = W.norm().item() if norm_preserve else 0.0
```
```diff
@@ -4823,7 +4877,10 @@ class AbliterationPipeline:
             if norm_preserve and original_norm > 0:
                 new_norm = W.norm().item()
                 if new_norm > 0:
-                    W.mul_(original_norm / new_norm)
+                    ratio = original_norm / new_norm
+                    if ratio > 1.10:
+                        ratio = 1.10
+                    W.mul_(ratio)
             count += 1
 
         if is_quantized and count > 0:
```
```diff
@@ -4907,7 +4964,10 @@ class AbliterationPipeline:
             if norm_preserve and original_norm > 0:
                 new_norm = W.norm().item()
                 if new_norm > 0:
-                    W.mul_(original_norm / new_norm)
+                    ratio = original_norm / new_norm
+                    if ratio > 1.10:
+                        ratio = 1.10
+                    W.mul_(ratio)
             count += 1
         elif W.shape[0] == d.shape[0]:
             original_norm = W.norm().item() if norm_preserve else 0.0
```
```diff
@@ -4921,7 +4981,10 @@ class AbliterationPipeline:
             if norm_preserve and original_norm > 0:
                 new_norm = W.norm().item()
                 if new_norm > 0:
-                    W.mul_(original_norm / new_norm)
+                    ratio = original_norm / new_norm
+                    if ratio > 1.10:
+                        ratio = 1.10
+                    W.mul_(ratio)
             count += 1
 
         if is_quantized and count > 0:
```
tests/test_abliterate.py
CHANGED
```diff
@@ -255,8 +255,14 @@ class TestProjectOutAdvanced:
         )
 
         new_norm = module.o_proj.weight.data.norm().item()
-        assert abs(new_norm - original_norm) < 1e-4, \
-            f"Norm should be preserved: {original_norm:.4f} vs {new_norm:.4f}"
+        # With amplification cap (1.10x max), exact norm preservation isn't
+        # guaranteed on tiny matrices (hidden_dim=4) where a single direction
+        # removes a large fraction of energy. Verify the norm is closer to
+        # original than the un-preserved norm would be (i.e. cap is working).
+        without_preserve_norm_sq = original_norm ** 2 - (module.o_proj.weight.data @ direction).pow(2).sum().item()
+        # The new norm should be >= the un-preserved norm (cap restores some)
+        assert new_norm >= original_norm * 0.85, \
+            f"Norm should be approximately preserved (within cap): {original_norm:.4f} vs {new_norm:.4f}"
 
     def test_regularization_partial_removal(self):
         """Regularization should preserve some of the refusal component."""
```
```diff
@@ -319,7 +325,11 @@ class TestProjectOutAdvanced:
         )
 
         new_norm = module.c_proj.weight.data.norm().item()
-        assert abs(new_norm - original_norm) < 1e-4, f"Norm should be preserved: {original_norm:.4f} vs {new_norm:.4f}"
+        # With amplification cap (1.10x max), exact norm preservation isn't
+        # guaranteed on tiny matrices where a single direction removes a large
+        # fraction of energy.
+        assert new_norm >= original_norm * 0.80, \
+            f"Norm should be approximately preserved (within cap): {original_norm:.4f} vs {new_norm:.4f}"
 
 
 # ---------------------------------------------------------------------------
```