Spaces:
Running on Zero
Running on Zero
| #!/usr/bin/env python3 | |
| """OBLITERATUS vs SOTA — Head-to-Head Benchmark Comparison. | |
| Runs faithful reproductions of competing abliteration methods against | |
| OBLITERATUS variants on any specified model, producing publication-ready | |
| comparison tables with standardized community metrics. | |
| Baselines included: | |
| 1. FailSpy/abliterator (2024) — Community workhorse baseline | |
| 2. Gabliteration (Gülmez 2026) — SVD multi-direction + ridge regularization | |
| 3. Heretic / p-e-w (2025) — Bayesian TPE auto-tuning (current SOTA for quality) | |
| 4. Wollschlager RDO (ICML 2025) — Gradient-based direction optimization | |
| OBLITERATUS variants: | |
| 5. OBLITERATUS surgical — Full SOTA MoE-aware pipeline | |
| 6. OBLITERATUS informed — Analysis-guided auto-configuration | |
| 7. OBLITERATUS optimized — Bayesian + whitened SVD + SAE (max OBLITERATUS) | |
| Evaluation protocol (Heretic community standard): | |
| - Refusal rate via substring + prefix detection | |
| - First-token KL divergence on harmless prompts | |
| - Capability probes (knowledge, truthfulness, math reasoning) | |
| - Optional: HarmBench ASR, lm-eval-harness benchmarks | |
| Usage: | |
| # Quick comparison (small model, few prompts) | |
| python scripts/benchmark_sota_comparison.py --model Qwen/Qwen2.5-1.5B-Instruct --quick | |
| # Full comparison on 8B model | |
| python scripts/benchmark_sota_comparison.py --model meta-llama/Llama-3.1-8B-Instruct | |
| # Specific baselines only | |
| python scripts/benchmark_sota_comparison.py --methods failspy heretic surgical | |
| # Custom prompt count and output | |
| python scripts/benchmark_sota_comparison.py --prompts 100 --output results.json | |
| # Include full Heretic evaluation protocol (HarmBench, lm-eval) | |
| python scripts/benchmark_sota_comparison.py --full-eval | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import gc | |
| import json | |
| import os | |
| import shutil | |
| import sys | |
| import time | |
| from dataclasses import asdict, dataclass | |
| from pathlib import Path | |
| os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") | |
| import torch | |
| # Ensure the project root is on sys.path | |
| project_root = Path(__file__).resolve().parent.parent | |
| sys.path.insert(0, str(project_root)) | |
| from obliteratus.abliterate import ( # noqa: E402 | |
| AbliterationPipeline, | |
| METHODS, | |
| HARMFUL_PROMPTS, | |
| HARMLESS_PROMPTS, | |
| ) | |
| from obliteratus.evaluation.benchmarks import BenchmarkRunner # noqa: E402 | |
# ── Method sets used by the comparison ────────────────────────────────
# Reproductions of competing abliteration methods.
BASELINE_METHODS = ["failspy", "gabliteration", "heretic", "rdo"]

# OBLITERATUS's own variants.
OBLITERATUS_METHODS = ["surgical", "informed", "optimized"]

# Default run: every baseline followed by every OBLITERATUS variant.
DEFAULT_METHODS = [*BASELINE_METHODS, *OBLITERATUS_METHODS]

# --quick run: leaves out the slow methods (Bayesian optimization).
QUICK_METHODS = ["failspy", "gabliteration", "rdo", "surgical"]
@dataclass
class MethodResult:
    """Metrics collected from one abliteration method run.

    The class must be a dataclass: instances are built with keyword
    arguments (``MethodResult(method=..., label=...)``) and serialized with
    ``dataclasses.asdict`` when results are written to JSON.
    """

    # Method key (e.g. "failspy") and its human-readable label.
    method: str
    label: str
    # Refusal rate measured on the harmful evaluation prompts.
    refusal_rate: float = 0.0
    # Mean first-token KL divergence vs. the original model (may be NaN).
    kl_divergence: float = 0.0
    # Capability-probe accuracies reported by BenchmarkRunner.
    knowledge_score: float = 0.0
    truthfulness_score: float = 0.0
    math_score: float = 0.0
    # Wall-clock seconds from pipeline construction through excision.
    ablation_time_s: float = 0.0
    # Peak CUDA memory allocated, in units of 1e6 bytes.
    peak_gpu_mb: float = 0.0
    # Number of layers the pipeline selected for modification.
    n_layers_modified: int = 0
    # NOTE(review): not populated anywhere in this script — confirm intent.
    n_projections: int = 0
    # Error message when the run failed; None on success.
    error: str | None = None
def parse_args():
    """Parse the benchmark's command-line options and return the namespace.

    The options are declared as a table of ``(flag, keyword-arguments)``
    pairs and registered in order, so ``--help`` lists them in the same
    sequence as the table. The module docstring is reused as the epilog.
    """
    cli = argparse.ArgumentParser(
        description="OBLITERATUS vs SOTA — Head-to-Head Benchmark",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )

    available = ", ".join(DEFAULT_METHODS)
    option_table = [
        ("--model", {
            "default": "Qwen/Qwen2.5-1.5B-Instruct",
            "help": "Model to benchmark (default: Qwen/Qwen2.5-1.5B-Instruct)",
        }),
        ("--methods", {
            "nargs": "+",
            "default": None,
            "help": f"Methods to compare (default: all). Available: {available}",
        }),
        ("--prompts", {
            "type": int,
            "default": None,
            "help": "Number of prompts for abliteration (default: auto based on --quick)",
        }),
        ("--eval-prompts", {
            "type": int,
            "default": 20,
            "help": "Number of prompts for refusal rate evaluation (default: 20)",
        }),
        ("--kl-prompts", {
            "type": int,
            "default": 10,
            "help": "Number of harmless prompts for KL divergence (default: 10)",
        }),
        ("--output", {
            "default": None,
            "help": "Output JSON file for results (default: stdout only)",
        }),
        ("--quick", {
            "action": "store_true",
            "help": "Quick mode: fewer prompts, skip slow methods (Bayesian opt)",
        }),
        ("--full-eval", {
            "action": "store_true",
            "help": "Run full Heretic evaluation protocol (HarmBench ASR, lm-eval)",
        }),
        ("--quantization", {
            "default": None,
            "choices": ["4bit", "8bit"],
            "help": "Quantization for model loading (default: auto)",
        }),
        ("--device", {
            "default": "auto",
            "help": "Device (default: auto)",
        }),
        ("--dtype", {
            "default": "float16",
            "help": "Data type (default: float16)",
        }),
        ("--output-dir", {
            "default": "/tmp/obliteratus_benchmark",
            "help": "Directory for abliterated model outputs (default: /tmp/obliteratus_benchmark)",
        }),
        ("--keep-models", {
            "action": "store_true",
            "help": "Keep abliterated models after benchmark (default: delete to save disk)",
        }),
    ]
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()
def measure_refusal_rate(
    pipeline: AbliterationPipeline,
    prompts: list[str],
    max_new_tokens: int = 128,
) -> float:
    """Generate a greedy response to each prompt and score the refusal rate.

    Args:
        pipeline: Loaded pipeline; ``pipeline.handle.model`` and
            ``pipeline.handle.tokenizer`` are used for generation.
        prompts: Prompts to generate responses for.
        max_new_tokens: Generation budget per prompt.

    Returns:
        Whatever ``refusal_rate(responses, mode="combined")`` returns for
        the generated responses (presumably the refused fraction — confirm
        against ``obliteratus.evaluation.advanced_metrics``).

    A prompt whose generation raises is reported on stdout and recorded as
    the placeholder response ``"[ERROR: ...]"`` so the response list stays
    aligned with ``prompts``.
    """
    from obliteratus.evaluation.advanced_metrics import refusal_rate

    model = pipeline.handle.model
    tokenizer = pipeline.handle.tokenizer
    responses = []
    for prompt in prompts:
        try:
            if pipeline.use_chat_template and hasattr(tokenizer, "apply_chat_template"):
                messages = [{"role": "user", "content": prompt}]
                text = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True,
                )
            else:
                text = prompt

            # Token id 0 is a legitimate pad id, so fall back to EOS only
            # when no pad token is configured at all (``or`` would treat 0
            # as missing).
            pad_token_id = tokenizer.pad_token_id
            if pad_token_id is None:
                pad_token_id = tokenizer.eos_token_id

            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    temperature=1.0,
                    pad_token_id=pad_token_id,
                )
            # Decode only the newly generated tokens, not the echoed prompt.
            prompt_length = inputs["input_ids"].shape[1]
            response = tokenizer.decode(
                outputs[0][prompt_length:],
                skip_special_tokens=True,
            )
            responses.append(response)
        except Exception as e:
            # Best effort: keep going, but make the failure visible since the
            # placeholder still takes part in the refusal-rate computation.
            print(f" Warning: generation failed for a prompt: {e}")
            responses.append(f"[ERROR: {e}]")
    return refusal_rate(responses, mode="combined")
def measure_kl_divergence(
    pipeline: AbliterationPipeline,
    original_logits: dict[int, torch.Tensor],
    prompts: list[str],
) -> float:
    """Return the mean first-token KL(original || current) over ``prompts``.

    Args:
        pipeline: Loaded pipeline whose (now modified) model is evaluated.
        original_logits: Last-position logits of the original model, keyed
            by prompt index (as produced by ``collect_baseline_logits``).
        prompts: The same prompt list the baseline logits were collected on;
            indices missing from ``original_logits`` are skipped.

    Returns:
        The mean KL divergence over the prompts that could be evaluated, or
        NaN when none could.

    Prompts whose forward pass raises are reported on stdout and skipped.
    """
    import math

    import torch.nn.functional as F

    model = pipeline.handle.model
    tokenizer = pipeline.handle.tokenizer
    kl_values = []
    for i, prompt in enumerate(prompts):
        if i not in original_logits:
            continue
        try:
            if pipeline.use_chat_template and hasattr(tokenizer, "apply_chat_template"):
                messages = [{"role": "user", "content": prompt}]
                text = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True,
                )
            else:
                text = prompt
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = model(**inputs)
            # Logits at the last input position, i.e. the next-token distribution.
            new_logits = outputs.logits[0, -1, :].float().cpu()
            orig = original_logits[i].float()
            log_p = F.log_softmax(orig, dim=-1)
            log_q = F.log_softmax(new_logits, dim=-1)
            # F.kl_div(input=log q, target=p) sums p * (log p - log q).
            kl = F.kl_div(log_q, log_p.exp(), reduction="sum").item()
        except Exception as e:
            print(f" Warning: KL evaluation failed for prompt {i}: {e}")
            continue
        if math.isnan(kl):
            continue
        # KL is non-negative in exact arithmetic; a slightly negative value is
        # rounding noise around zero. Clamp it instead of discarding the
        # prompt, which would bias the mean upward.
        kl_values.append(max(kl, 0.0))
    return sum(kl_values) / len(kl_values) if kl_values else float("nan")
def collect_baseline_logits(
    pipeline: AbliterationPipeline,
    prompts: list[str],
) -> dict[int, torch.Tensor]:
    """Collect first-token logits from the model currently held by ``pipeline``.

    Intended to be called before abliteration so the result can serve as the
    reference distribution for ``measure_kl_divergence``.

    Args:
        pipeline: Loaded pipeline; ``pipeline.handle.model`` and
            ``pipeline.handle.tokenizer`` are used.
        prompts: Prompts to run a forward pass on.

    Returns:
        Mapping from prompt index to the float32 CPU logits at the last
        input position. Prompts whose forward pass raises are reported on
        stdout and left out of the mapping.
    """
    model = pipeline.handle.model
    tokenizer = pipeline.handle.tokenizer
    logits = {}
    for i, prompt in enumerate(prompts):
        try:
            if pipeline.use_chat_template and hasattr(tokenizer, "apply_chat_template"):
                messages = [{"role": "user", "content": prompt}]
                text = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True,
                )
            else:
                text = prompt
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = model(**inputs)
            # Keep a CPU copy so the reference does not occupy device memory.
            logits[i] = outputs.logits[0, -1, :].float().cpu()
        except Exception as e:
            # Best effort: a missing index is simply skipped by the KL step,
            # but say so instead of shrinking the sample silently.
            print(f" Warning: baseline logits unavailable for prompt {i}: {e}")
    return logits
def _build_pipeline(
    model_name: str,
    method: str,
    output_dir: Path,
    harmful_prompts: list[str],
    harmless_prompts: list[str],
    args: argparse.Namespace,
):
    """Construct the pipeline object that implements *method*."""
    shared_kwargs = dict(
        model_name=model_name,
        output_dir=str(output_dir),
        device=args.device,
        dtype=args.dtype,
        quantization=args.quantization,
        harmful_prompts=harmful_prompts,
        harmless_prompts=harmless_prompts,
        on_log=lambda msg: print(f" {msg}"),
    )
    if method == "informed":
        # 'informed' has its own pipeline class and takes neither ``method``
        # nor ``use_chat_template``.
        from obliteratus.informed_pipeline import InformedAbliterationPipeline
        return InformedAbliterationPipeline(**shared_kwargs)
    return AbliterationPipeline(method=method, use_chat_template=True, **shared_kwargs)


def run_single_method(
    model_name: str,
    method: str,
    harmful_prompts: list[str],
    harmless_prompts: list[str],
    eval_harmful: list[str],
    eval_harmless: list[str],
    args: argparse.Namespace,
) -> MethodResult:
    """Run one abliteration method end to end and collect its metrics.

    Steps: build the pipeline, load the model, record baseline logits on
    ``eval_harmless``, run probe/distill/excise, then measure refusal rate,
    KL divergence and the capability probes (plus the optional Heretic
    protocol when ``args.full_eval`` is set).

    Args:
        model_name: Model identifier passed to the pipeline.
        method: Method key; ``"informed"`` selects the informed pipeline.
        harmful_prompts: Harmful prompts used for abliteration.
        harmless_prompts: Harmless prompts used for abliteration.
        eval_harmful: Harmful prompts for the refusal-rate measurement.
        eval_harmless: Harmless prompts for the KL measurement.
        args: Parsed CLI options (device, dtype, quantization, output_dir,
            keep_models, full_eval are read here).

    Returns:
        A ``MethodResult``. Any exception is caught, printed with its
        traceback and recorded in ``result.error``; the function does not
        raise for method failures.
    """
    label = METHODS.get(method, {}).get("label", method)
    result = MethodResult(method=method, label=label)
    print(f"\n{'='*70}")
    print(f" Method: {label}")
    print(f"{'='*70}")
    output_dir = Path(args.output_dir) / method

    # Pre-bind so the cleanup in ``finally`` can drop them unconditionally.
    pipeline = None
    runner = None
    try:
        # Track GPU memory
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        t0 = time.time()

        pipeline = _build_pipeline(
            model_name, method, output_dir, harmful_prompts, harmless_prompts, args,
        )

        # Phase 1: Load model + collect baseline KL logits
        print(" Loading model...")
        pipeline._summon()
        print(" Collecting baseline logits for KL divergence...")
        baseline_logits = collect_baseline_logits(pipeline, eval_harmless)

        # Phase 2: Run abliteration pipeline
        print(" Probing activations...")
        pipeline._probe()
        print(" Extracting refusal directions...")
        pipeline._distill()
        result.n_layers_modified = len(pipeline._strong_layers)
        print(f" Excising refusal ({result.n_layers_modified} layers)...")
        pipeline._excise()
        result.ablation_time_s = time.time() - t0

        # Peak is sampled before evaluation, so it covers load + abliteration.
        if torch.cuda.is_available():
            result.peak_gpu_mb = torch.cuda.max_memory_allocated() / 1e6

        # Phase 3: Evaluate
        print(f" Evaluating refusal rate ({len(eval_harmful)} prompts)...")
        result.refusal_rate = measure_refusal_rate(pipeline, eval_harmful)
        print(f" Evaluating KL divergence ({len(eval_harmless)} prompts)...")
        result.kl_divergence = measure_kl_divergence(pipeline, baseline_logits, eval_harmless)

        # Capability probes (failure here must not fail the whole method)
        print(" Running capability probes...")
        try:
            runner = BenchmarkRunner(
                pipeline.handle.model,
                pipeline.handle.tokenizer,
            )
            bench_result = runner.run_all()
            result.knowledge_score = bench_result.knowledge.accuracy if bench_result.knowledge else 0.0
            result.truthfulness_score = bench_result.truthfulness.accuracy if bench_result.truthfulness else 0.0
            result.math_score = bench_result.math.accuracy if bench_result.math else 0.0
        except Exception as e:
            print(f" Warning: capability probes failed: {e}")

        # Optional: full Heretic evaluation
        if args.full_eval:
            print(" Running full Heretic evaluation protocol...")
            try:
                from obliteratus.evaluation.heretic_eval import run_full_heretic_eval
                heretic_result = run_full_heretic_eval(
                    model=pipeline.handle.model,
                    tokenizer=pipeline.handle.tokenizer,
                    original_model=None,  # Would need original for full comparison
                )
                print(f" Heretic eval: ASR={heretic_result.harmbench_asr:.1%}, "
                      f"JB_refusal={heretic_result.jailbreakbench_refusal_rate:.1%}")
            except Exception as e:
                print(f" Warning: Heretic eval failed: {e}")

        print(f" ✓ Complete: refusal={result.refusal_rate:.1%}, KL={result.kl_divergence:.4f}, "
              f"time={result.ablation_time_s:.1f}s")
    except Exception as e:
        # Some exceptions stringify to "", which downstream truthiness checks
        # would mistake for success; fall back to the exception type name.
        result.error = str(e) or type(e).__name__
        print(f" ✗ FAILED: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Release our references to the model *before* collecting: while
        # ``pipeline``/``runner`` are still bound, gc.collect() and
        # empty_cache() cannot free the weights for the next method.
        pipeline = None
        runner = None
        if not args.keep_models and output_dir.exists():
            shutil.rmtree(output_dir, ignore_errors=True)
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return result
def _format_result_row(r: MethodResult) -> str:
    """Render one table row: the metrics of a successful run, or FAILED plus the truncated error."""
    # ``is not None`` (not truthiness) so that a failure with an empty
    # message is still shown as a failure, matching the summary's filter.
    if r.error is not None:
        return f" {r.label:<33} {'FAILED':>10} {r.error[:60]}"
    return (
        f" {r.label:<33} {r.refusal_rate:>9.1%} {r.kl_divergence:>10.4f} "
        f"{r.knowledge_score:>7.1%} {r.truthfulness_score:>7.1%} {r.math_score:>7.1%} "
        f"{r.ablation_time_s:>7.1f}s {r.n_layers_modified:>6}"
    )


def format_comparison_table(results: list[MethodResult]) -> str:
    """Format results as a plain-text comparison table.

    Baseline methods (those listed in ``BASELINE_METHODS``) are printed
    first, everything else under the OBLITERATUS heading. A "best of"
    summary over the successful runs and a metric legend follow.

    Args:
        results: One ``MethodResult`` per method, in run order.

    Returns:
        The table as a single newline-joined string.
    """
    import math

    lines = []
    # Title
    lines.append("")
    lines.append("=" * 115)
    lines.append("OBLITERATUS vs SOTA — Head-to-Head Benchmark Comparison")
    lines.append("=" * 115)
    lines.append("")
    # Column header
    lines.append(f"{'Method':<35} {'Refusal↓':>10} {'KL↓':>10} {'Know↑':>8} {'Truth↑':>8} {'Math↑':>8} {'Time':>8} {'Layers':>7}")
    lines.append("-" * 115)

    baseline_results = [r for r in results if r.method in BASELINE_METHODS]
    obliteratus_results = [r for r in results if r.method not in BASELINE_METHODS]
    if baseline_results:
        lines.append(" BASELINES:")
        lines.extend(_format_result_row(r) for r in baseline_results)
    if obliteratus_results:
        lines.append(" OBLITERATUS:")
        lines.extend(_format_result_row(r) for r in obliteratus_results)
    lines.append("-" * 115)

    # Best values among the runs that completed
    successful = [r for r in results if r.error is None]
    if successful:
        best_refusal = min(successful, key=lambda r: r.refusal_rate)
        best_knowledge = max(successful, key=lambda r: r.knowledge_score)
        # A NaN KL means the metric could not be measured; such a run must
        # not be named "best".
        with_kl = [r for r in successful if not math.isnan(r.kl_divergence)]
        lines.append(f" Best refusal removal: {best_refusal.label} ({best_refusal.refusal_rate:.1%})")
        if with_kl:
            best_kl = min(with_kl, key=lambda r: r.kl_divergence)
            lines.append(f" Best quality preservation: {best_kl.label} (KL={best_kl.kl_divergence:.4f})")
        else:
            lines.append(" Best quality preservation: n/a (KL not measured)")
        lines.append(f" Best knowledge retention: {best_knowledge.label} ({best_knowledge.knowledge_score:.1%})")
    lines.append("=" * 115)
    lines.append("")

    # Metric interpretation guide
    lines.append("Metrics:")
    lines.append(" Refusal↓ = fraction of harmful prompts still refused (lower = more effective abliteration)")
    lines.append(" KL↓ = first-token KL divergence on harmless prompts (lower = better quality preservation)")
    lines.append(" Know↑ = MMLU-style knowledge probe accuracy (higher = better capability)")
    lines.append(" Truth↑ = TruthfulQA-style probe accuracy (higher = better calibration)")
    lines.append(" Math↑ = GSM8K-style math reasoning accuracy (higher = better reasoning)")
    lines.append("")
    return "\n".join(lines)
def main():
    """Entry point: run every selected method, print the table, save JSON.

    Reads the CLI options, picks the method list and prompt subsets, calls
    ``run_single_method`` once per method, prints the comparison table and,
    when ``--output`` is given, writes the results as JSON. Exits with
    status 1 on an unknown method name.
    """
    import math

    args = parse_args()
    print("=" * 70)
    print(" OBLITERATUS vs SOTA — Head-to-Head Benchmark")
    print(f" Model: {args.model}")
    print("=" * 70)

    # Determine methods to run
    methods = args.methods or (QUICK_METHODS if args.quick else DEFAULT_METHODS)

    # Validate methods ('informed' is handled by its own pipeline class)
    valid_methods = set(METHODS.keys()) | {"informed"}
    for m in methods:
        if m not in valid_methods:
            print(f"Error: unknown method '{m}'. Available: {sorted(valid_methods)}")
            sys.exit(1)
    print(f" Methods: {', '.join(methods)}")

    # Determine prompt counts (capped by what the prompt lists provide)
    n_prompts = args.prompts or (50 if args.quick else 128)
    n_prompts = min(n_prompts, len(HARMFUL_PROMPTS), len(HARMLESS_PROMPTS))
    harmful_prompts = HARMFUL_PROMPTS[:n_prompts]
    harmless_prompts = HARMLESS_PROMPTS[:n_prompts]

    # Evaluation subsets: held out from the abliteration prompts when the
    # lists are long enough, otherwise taken from the start of the list.
    eval_harmful = HARMFUL_PROMPTS[n_prompts:n_prompts + args.eval_prompts]
    if len(eval_harmful) < args.eval_prompts:
        eval_harmful = HARMFUL_PROMPTS[:args.eval_prompts]
        print(" Warning: not enough held-out harmful prompts; "
              "evaluation prompts overlap the abliteration prompts")
    eval_harmless = HARMLESS_PROMPTS[n_prompts:n_prompts + args.kl_prompts]
    if len(eval_harmless) < args.kl_prompts:
        eval_harmless = HARMLESS_PROMPTS[:args.kl_prompts]
        print(" Warning: not enough held-out harmless prompts; "
              "KL prompts overlap the abliteration prompts")
    print(f" Abliteration prompts: {n_prompts} harmful + {n_prompts} harmless")
    print(f" Evaluation prompts: {len(eval_harmful)} harmful, {len(eval_harmless)} harmless")
    print()

    # Run each method
    results: list[MethodResult] = []
    for method in methods:
        result = run_single_method(
            model_name=args.model,
            method=method,
            harmful_prompts=harmful_prompts,
            harmless_prompts=harmless_prompts,
            eval_harmful=eval_harmful,
            eval_harmless=eval_harmless,
            args=args,
        )
        results.append(result)

    # Print comparison table
    table = format_comparison_table(results)
    print(table)

    # Save results
    if args.output:
        output_path = Path(args.output)
        # NaN/inf would be emitted as bare NaN/Infinity tokens, which are not
        # valid JSON; store unmeasurable metrics as null instead.
        method_records = [
            {
                key: (None if isinstance(value, float) and not math.isfinite(value) else value)
                for key, value in asdict(r).items()
            }
            for r in results
        ]
        output_data = {
            "model": args.model,
            "n_prompts": n_prompts,
            "n_eval_harmful": len(eval_harmful),
            "n_eval_harmless": len(eval_harmless),
            "methods": method_records,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
        }
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text(
            json.dumps(output_data, indent=2, default=str),
            encoding="utf-8",
        )
        print(f"Results saved to {output_path}")
| if __name__ == "__main__": | |
| main() | |