Spaces:
Running on Zero
Running on Zero
Upload 134 files
Browse files- app.py +453 -217
- obliteratus/.DS_Store +0 -0
- obliteratus/abliterate.py +43 -42
- obliteratus/analysis/sae_abliteration.py +4 -3
- obliteratus/device.py +305 -0
- obliteratus/evaluation/benchmarks.py +2 -2
- obliteratus/evaluation/heretic_eval.py +6 -8
- obliteratus/interactive.py +14 -9
- obliteratus/mlx_backend.py +469 -0
- obliteratus/models/loader.py +60 -62
- obliteratus/reproducibility.py +2 -2
- obliteratus/tourney.py +2 -3
- requirements-apple.txt +8 -0
app.py
CHANGED
|
@@ -57,6 +57,7 @@ if "HF_HOME" not in os.environ:
|
|
| 57 |
|
| 58 |
import gradio as gr
|
| 59 |
import torch
|
|
|
|
| 60 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
| 61 |
|
| 62 |
# ── ZeroGPU support ─────────────────────────────────────────────────
|
|
@@ -399,6 +400,213 @@ def _validate_hub_repo(hub_repo: str) -> str:
|
|
| 399 |
return ""
|
| 400 |
|
| 401 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 402 |
PROMPT_VOLUMES = {
|
| 403 |
"33 (fast)": 33,
|
| 404 |
"66 (better signal)": 66,
|
|
@@ -447,25 +655,11 @@ def _should_quantize(model_id: str, is_preset: bool = False) -> str | None:
|
|
| 447 |
# ---------------------------------------------------------------------------
|
| 448 |
|
| 449 |
def _clear_gpu():
|
| 450 |
-
"""Free GPU memory. Resilient to
|
| 451 |
with _lock:
|
| 452 |
_state["model"] = None
|
| 453 |
_state["tokenizer"] = None
|
| 454 |
-
|
| 455 |
-
if torch.cuda.is_available():
|
| 456 |
-
try:
|
| 457 |
-
torch.cuda.empty_cache()
|
| 458 |
-
except Exception:
|
| 459 |
-
# CUDA context may be poisoned after an illegal-address error;
|
| 460 |
-
# attempt a device reset so subsequent loads can succeed.
|
| 461 |
-
try:
|
| 462 |
-
torch.cuda.synchronize()
|
| 463 |
-
except Exception:
|
| 464 |
-
pass
|
| 465 |
-
try:
|
| 466 |
-
torch.cuda.reset_peak_memory_stats()
|
| 467 |
-
except Exception:
|
| 468 |
-
pass
|
| 469 |
|
| 470 |
|
| 471 |
def _install_steering_hooks(model, steering_meta: dict) -> int:
|
|
@@ -589,16 +783,16 @@ def _cleanup_disk():
|
|
| 589 |
# ---------------------------------------------------------------------------
|
| 590 |
|
| 591 |
def _get_vram_html() -> str:
|
| 592 |
-
"""Return an HTML snippet showing GPU
|
| 593 |
-
if not
|
| 594 |
return (
|
| 595 |
'<div style="text-align:center;color:#4a5568;font-size:0.72rem;'
|
| 596 |
'letter-spacing:1px;margin-top:6px;">CPU ONLY — NO GPU DETECTED</div>'
|
| 597 |
)
|
| 598 |
try:
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
total =
|
| 602 |
pct = (used / total * 100) if total > 0 else 0
|
| 603 |
# Color shifts from green → yellow → red
|
| 604 |
if pct < 50:
|
|
@@ -607,12 +801,17 @@ def _get_vram_html() -> str:
|
|
| 607 |
bar_color = "#ffcc00"
|
| 608 |
else:
|
| 609 |
bar_color = "#ff003c"
|
| 610 |
-
device_name =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 611 |
return (
|
| 612 |
f'<div style="margin:6px auto 0;max-width:480px;">'
|
| 613 |
f'<div style="display:flex;justify-content:space-between;font-size:0.68rem;'
|
| 614 |
f'color:#4a5568;letter-spacing:1px;margin-bottom:2px;">'
|
| 615 |
-
f'<span>
|
| 616 |
f'<span>{used:.1f} / {total:.1f} GB ({pct:.0f}%)</span></div>'
|
| 617 |
f'<div style="background:#0a0a0f;border:1px solid #1a1f2e;border-radius:3px;'
|
| 618 |
f'height:10px;overflow:hidden;">'
|
|
@@ -620,11 +819,11 @@ def _get_vram_html() -> str:
|
|
| 620 |
f'box-shadow:0 0 6px {bar_color};transition:width 0.5s ease;"></div></div>'
|
| 621 |
f'<div style="display:flex;justify-content:space-between;font-size:0.6rem;'
|
| 622 |
f'color:#333;margin-top:1px;">'
|
| 623 |
-
f'
|
| 624 |
f'</div>'
|
| 625 |
)
|
| 626 |
except Exception:
|
| 627 |
-
return '<div style="text-align:center;color:#4a5568;font-size:0.72rem;">
|
| 628 |
|
| 629 |
|
| 630 |
# ---------------------------------------------------------------------------
|
|
@@ -1067,8 +1266,7 @@ def benchmark(
|
|
| 1067 |
pass
|
| 1068 |
pipeline_ref[0] = None
|
| 1069 |
gc.collect()
|
| 1070 |
-
|
| 1071 |
-
torch.cuda.empty_cache()
|
| 1072 |
|
| 1073 |
yield (
|
| 1074 |
f"**{method_key} complete** ({mi + 1}/{len(methods_to_test)}) \u2014 {_bench_elapsed()}",
|
|
@@ -1418,8 +1616,7 @@ def benchmark_multi_model(
|
|
| 1418 |
pass
|
| 1419 |
pipeline_ref[0] = None
|
| 1420 |
gc.collect()
|
| 1421 |
-
|
| 1422 |
-
torch.cuda.empty_cache()
|
| 1423 |
|
| 1424 |
yield (
|
| 1425 |
f"**{model_id} complete** ({mi + 1}/{len(model_choices)}) \u2014 {_mm_elapsed()}",
|
|
@@ -1518,7 +1715,6 @@ def _format_multi_model_results(results: list[dict], context: dict | None = None
|
|
| 1518 |
|
| 1519 |
@spaces.GPU(duration=300)
|
| 1520 |
def obliterate(model_choice: str, method_choice: str,
|
| 1521 |
-
hub_auto_push: bool, hub_repo: str,
|
| 1522 |
prompt_volume_choice: str, dataset_source_choice: str,
|
| 1523 |
custom_harmful: str, custom_harmless: str,
|
| 1524 |
# Advanced params (sliders)
|
|
@@ -1551,14 +1747,6 @@ def obliterate(model_choice: str, method_choice: str,
|
|
| 1551 |
model_id = MODELS.get(model_choice, model_choice)
|
| 1552 |
is_preset = model_choice in MODELS
|
| 1553 |
method = METHODS.get(method_choice, "advanced")
|
| 1554 |
-
# Resolve push-to-hub: explicit repo overrides auto-naming
|
| 1555 |
-
_hub_override = hub_repo.strip() if hub_repo and hub_repo.strip() else None
|
| 1556 |
-
if _hub_override:
|
| 1557 |
-
push_to_hub = _hub_override
|
| 1558 |
-
elif hub_auto_push:
|
| 1559 |
-
push_to_hub = "auto" # resolved to {user}/{model}-OBLITERATED at push time
|
| 1560 |
-
else:
|
| 1561 |
-
push_to_hub = None
|
| 1562 |
prompt_volume = PROMPT_VOLUMES.get(prompt_volume_choice, 33)
|
| 1563 |
|
| 1564 |
# Resolve "adaptive" → telemetry-recommended method for this model
|
|
@@ -1606,26 +1794,6 @@ def obliterate(model_choice: str, method_choice: str,
|
|
| 1606 |
)
|
| 1607 |
return
|
| 1608 |
|
| 1609 |
-
# Early validation: Hub repo format + token availability
|
| 1610 |
-
# Resolve which token to use: user's own HF_TOKEN, or the shared community token.
|
| 1611 |
-
_user_token = os.environ.get("HF_TOKEN")
|
| 1612 |
-
_hub_token = _user_token or _HUB_COMMUNITY_TOKEN
|
| 1613 |
-
_hub_org = None if _user_token else _HUB_COMMUNITY_ORG # community org only when using shared token
|
| 1614 |
-
if push_to_hub:
|
| 1615 |
-
if push_to_hub != "auto" and not re.match(r'^[a-zA-Z0-9_-]+/[a-zA-Z0-9_.-]+$', push_to_hub):
|
| 1616 |
-
yield (
|
| 1617 |
-
"**Error:** Invalid Hub repo format. Use `username/model-name`.",
|
| 1618 |
-
"", gr.update(), gr.update(), gr.update(), gr.update(),
|
| 1619 |
-
)
|
| 1620 |
-
return
|
| 1621 |
-
if not _hub_token:
|
| 1622 |
-
yield (
|
| 1623 |
-
"**Error:** No Hub token available. Set HF_TOKEN or OBLITERATUS_HUB_TOKEN "
|
| 1624 |
-
"as an environment variable or Space secret.",
|
| 1625 |
-
"", gr.update(), gr.update(), gr.update(), gr.update(),
|
| 1626 |
-
)
|
| 1627 |
-
return
|
| 1628 |
-
|
| 1629 |
# Resolve dataset source — custom prompts override the dropdown
|
| 1630 |
use_custom = custom_harmful and custom_harmful.strip()
|
| 1631 |
dataset_key = get_source_key_from_label(dataset_source_choice) if dataset_source_choice else "builtin"
|
|
@@ -1699,9 +1867,6 @@ def obliterate(model_choice: str, method_choice: str,
|
|
| 1699 |
output_dir=save_dir,
|
| 1700 |
device="auto",
|
| 1701 |
dtype="float16",
|
| 1702 |
-
push_to_hub=push_to_hub,
|
| 1703 |
-
hub_token=_hub_token,
|
| 1704 |
-
hub_community_org=_hub_org,
|
| 1705 |
quantization=quantization,
|
| 1706 |
trust_remote_code=is_preset,
|
| 1707 |
harmful_prompts=harmful_all[:n],
|
|
@@ -1719,9 +1884,6 @@ def obliterate(model_choice: str, method_choice: str,
|
|
| 1719 |
device="auto",
|
| 1720 |
dtype="float16",
|
| 1721 |
method=method,
|
| 1722 |
-
push_to_hub=push_to_hub,
|
| 1723 |
-
hub_token=_hub_token,
|
| 1724 |
-
hub_community_org=_hub_org,
|
| 1725 |
quantization=quantization,
|
| 1726 |
trust_remote_code=is_preset,
|
| 1727 |
harmful_prompts=harmful_all[:n],
|
|
@@ -1774,12 +1936,6 @@ def obliterate(model_choice: str, method_choice: str,
|
|
| 1774 |
log_lines.append(f"Dataset: {source_label}")
|
| 1775 |
vol_label = "all" if prompt_volume == -1 else str(prompt_volume)
|
| 1776 |
log_lines.append(f"Prompt volume: {vol_label} pairs")
|
| 1777 |
-
if push_to_hub:
|
| 1778 |
-
if push_to_hub == "auto":
|
| 1779 |
-
_ns = _hub_org or "{you}"
|
| 1780 |
-
log_lines.append(f"Push to Hub: auto ({_ns}/{{model}}-OBLITERATED)")
|
| 1781 |
-
else:
|
| 1782 |
-
log_lines.append(f"Push to Hub: {push_to_hub}")
|
| 1783 |
if quantization:
|
| 1784 |
log_lines.append(f"Quantization: {quantization} (auto-detected for GPU fit)")
|
| 1785 |
log_lines.append("")
|
|
@@ -2118,11 +2274,11 @@ def chat_respond(message: str, history: list[dict], system_prompt: str,
|
|
| 2118 |
_needs_reload = model is None or tokenizer is None
|
| 2119 |
if not _needs_reload:
|
| 2120 |
try:
|
| 2121 |
-
|
| 2122 |
-
if
|
| 2123 |
_needs_reload = True
|
| 2124 |
-
elif
|
| 2125 |
-
model.to(
|
| 2126 |
except Exception:
|
| 2127 |
_needs_reload = True
|
| 2128 |
|
|
@@ -2552,11 +2708,11 @@ def ab_chat_respond(message: str, history_left: list[dict], history_right: list[
|
|
| 2552 |
_needs_reload = abliterated_model is None or tokenizer is None
|
| 2553 |
if not _needs_reload:
|
| 2554 |
try:
|
| 2555 |
-
|
| 2556 |
-
if
|
| 2557 |
_needs_reload = True
|
| 2558 |
-
elif
|
| 2559 |
-
abliterated_model.to(
|
| 2560 |
except Exception:
|
| 2561 |
_needs_reload = True
|
| 2562 |
|
|
@@ -2689,8 +2845,7 @@ def ab_chat_respond(message: str, history_left: list[dict], history_right: list[
|
|
| 2689 |
abl_device = next(abliterated_model.parameters()).device
|
| 2690 |
abliterated_model.to("cpu")
|
| 2691 |
gc.collect()
|
| 2692 |
-
|
| 2693 |
-
torch.cuda.empty_cache()
|
| 2694 |
|
| 2695 |
model_id = MODELS.get(model_name, model_name)
|
| 2696 |
# Only trust remote code for known preset models, not arbitrary user-supplied IDs
|
|
@@ -2742,8 +2897,7 @@ def ab_chat_respond(message: str, history_left: list[dict], history_right: list[
|
|
| 2742 |
# Free the original model
|
| 2743 |
del original_model
|
| 2744 |
gc.collect()
|
| 2745 |
-
|
| 2746 |
-
torch.cuda.empty_cache()
|
| 2747 |
|
| 2748 |
except Exception as e:
|
| 2749 |
original_response = f"*Could not load original model for comparison: {e}*"
|
|
@@ -2752,7 +2906,7 @@ def ab_chat_respond(message: str, history_left: list[dict], history_right: list[
|
|
| 2752 |
# Use torch.device("cuda") rather than the captured abl_device, since
|
| 2753 |
# on ZeroGPU the original device reference may point to a stale context.
|
| 2754 |
try:
|
| 2755 |
-
restore_device = torch.device(
|
| 2756 |
abliterated_model.to(restore_device)
|
| 2757 |
except Exception:
|
| 2758 |
pass # If GPU restore fails, model stays on CPU (still usable)
|
|
@@ -2870,8 +3024,7 @@ def strength_sweep(model_choice: str, method_choice: str,
|
|
| 2870 |
|
| 2871 |
# Cleanup between runs
|
| 2872 |
gc.collect()
|
| 2873 |
-
|
| 2874 |
-
torch.cuda.empty_cache()
|
| 2875 |
|
| 2876 |
# Generate dose-response curve
|
| 2877 |
gallery = None
|
|
@@ -2963,6 +3116,117 @@ def _format_sweep_results(results: list[dict]) -> str:
|
|
| 2963 |
return "\n".join(lines)
|
| 2964 |
|
| 2965 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2966 |
# ---------------------------------------------------------------------------
|
| 2967 |
# Export Research Artifacts
|
| 2968 |
# ---------------------------------------------------------------------------
|
|
@@ -3523,20 +3787,10 @@ with gr.Blocks(theme=THEME, css=CSS, js=_JS, title="OBLITERATUS", fill_height=Tr
|
|
| 3523 |
lines=5,
|
| 3524 |
)
|
| 3525 |
|
| 3526 |
-
|
| 3527 |
-
|
| 3528 |
-
|
| 3529 |
-
|
| 3530 |
-
info=f"Pushes your model to {_HUB_COMMUNITY_ORG}/{{model}}-OBLITERATED on HF Hub. "
|
| 3531 |
-
"No token needed — works out of the box!",
|
| 3532 |
-
)
|
| 3533 |
-
hub_repo = gr.Textbox(
|
| 3534 |
-
label="Push to Hub (optional override)",
|
| 3535 |
-
placeholder="auto-filled when checkbox is ticked, or type your own",
|
| 3536 |
-
info="Leave blank with checkbox ticked for auto-naming, "
|
| 3537 |
-
"or enter a custom repo ID (e.g. your-username/my-model).",
|
| 3538 |
-
)
|
| 3539 |
-
hub_warning_md = gr.Markdown("")
|
| 3540 |
|
| 3541 |
# ── Advanced Settings (auto-populated from method preset) ────
|
| 3542 |
_defaults = _get_preset_defaults("advanced (recommended)")
|
|
@@ -4168,33 +4422,19 @@ tradeoff point where refusal is minimized with minimal capability damage.
|
|
| 4168 |
with gr.Tab("Tourney", id="tourney"):
|
| 4169 |
gr.Markdown("""### March Madness Tournament
|
| 4170 |
Pit **all abliteration methods** against each other in elimination rounds.
|
| 4171 |
-
The winner
|
| 4172 |
|
| 4173 |
**Round 1 — Qualifiers:** All methods, reduced prompts. Bottom half eliminated.
|
| 4174 |
**Round 2 — Semifinals:** Survivors, full prompts. Bottom half eliminated.
|
| 4175 |
**Round 3 — Finals:** Top contenders, maximum prompts. Champion crowned.
|
| 4176 |
""")
|
| 4177 |
-
|
| 4178 |
-
|
| 4179 |
-
|
| 4180 |
-
|
| 4181 |
-
|
| 4182 |
-
|
| 4183 |
-
|
| 4184 |
-
allow_custom_value=True,
|
| 4185 |
-
)
|
| 4186 |
-
with gr.Column(scale=1):
|
| 4187 |
-
tourney_hub_org = gr.Textbox(
|
| 4188 |
-
label="HF Hub Org (optional)",
|
| 4189 |
-
placeholder="my-org",
|
| 4190 |
-
info="Push winner to hub-org/model-name-OBLITERATED",
|
| 4191 |
-
)
|
| 4192 |
-
with gr.Column(scale=1):
|
| 4193 |
-
tourney_hub_repo = gr.Textbox(
|
| 4194 |
-
label="HF Hub Repo (optional)",
|
| 4195 |
-
placeholder="org/repo-name",
|
| 4196 |
-
info="Overrides Hub Org — full repo ID",
|
| 4197 |
-
)
|
| 4198 |
|
| 4199 |
with gr.Accordion("Advanced Settings", open=False):
|
| 4200 |
with gr.Row():
|
|
@@ -4223,97 +4463,12 @@ The winner gets auto-pushed to HuggingFace Hub.
|
|
| 4223 |
interactive=False,
|
| 4224 |
)
|
| 4225 |
|
| 4226 |
-
|
| 4227 |
-
|
|
|
|
| 4228 |
tourney_dataset_dd, tourney_quant_dd],
|
| 4229 |
outputs=[tourney_status, tourney_bracket, tourney_log],
|
| 4230 |
)
|
| 4231 |
-
def run_tourney(model_choice, hub_org, hub_repo, dataset, quantization):
|
| 4232 |
-
if not model_choice or not model_choice.strip():
|
| 4233 |
-
yield "**Error:** Select a model first.", "", ""
|
| 4234 |
-
return
|
| 4235 |
-
|
| 4236 |
-
from obliteratus.tourney import TourneyRunner, render_bracket
|
| 4237 |
-
|
| 4238 |
-
# Resolve display label → HuggingFace model ID
|
| 4239 |
-
model_id = model_choice.strip()
|
| 4240 |
-
if model_id in MODELS:
|
| 4241 |
-
model_id = MODELS[model_id]
|
| 4242 |
-
|
| 4243 |
-
hub_org_val = hub_org.strip() if hub_org and hub_org.strip() else None
|
| 4244 |
-
hub_repo_val = hub_repo.strip() if hub_repo and hub_repo.strip() else None
|
| 4245 |
-
quant = quantization if quantization != "none" else None
|
| 4246 |
-
|
| 4247 |
-
log_lines = []
|
| 4248 |
-
|
| 4249 |
-
def on_log(msg):
|
| 4250 |
-
log_lines.append(msg)
|
| 4251 |
-
|
| 4252 |
-
def on_round(rnd):
|
| 4253 |
-
pass # logged via on_log
|
| 4254 |
-
|
| 4255 |
-
dataset_key = get_source_key_from_label(dataset) if dataset else "builtin"
|
| 4256 |
-
|
| 4257 |
-
runner = TourneyRunner(
|
| 4258 |
-
model_name=model_id,
|
| 4259 |
-
hub_org=hub_org_val,
|
| 4260 |
-
hub_repo=hub_repo_val,
|
| 4261 |
-
dataset_key=dataset_key,
|
| 4262 |
-
quantization=quant,
|
| 4263 |
-
on_log=on_log,
|
| 4264 |
-
on_round=on_round,
|
| 4265 |
-
)
|
| 4266 |
-
|
| 4267 |
-
# Yield progress updates during tournament
|
| 4268 |
-
import threading
|
| 4269 |
-
result_ref = [None]
|
| 4270 |
-
error_ref = [None]
|
| 4271 |
-
|
| 4272 |
-
def _run():
|
| 4273 |
-
try:
|
| 4274 |
-
result_ref[0] = runner.run()
|
| 4275 |
-
except Exception as e:
|
| 4276 |
-
error_ref[0] = e
|
| 4277 |
-
|
| 4278 |
-
thread = threading.Thread(target=_run, daemon=True)
|
| 4279 |
-
thread.start()
|
| 4280 |
-
|
| 4281 |
-
while thread.is_alive():
|
| 4282 |
-
yield (
|
| 4283 |
-
"**Tournament in progress...**",
|
| 4284 |
-
"",
|
| 4285 |
-
"\n".join(log_lines[-100:]),
|
| 4286 |
-
)
|
| 4287 |
-
time.sleep(1.0)
|
| 4288 |
-
|
| 4289 |
-
thread.join(timeout=10)
|
| 4290 |
-
|
| 4291 |
-
if error_ref[0]:
|
| 4292 |
-
yield (
|
| 4293 |
-
f"**Error:** {error_ref[0]}",
|
| 4294 |
-
"",
|
| 4295 |
-
"\n".join(log_lines),
|
| 4296 |
-
)
|
| 4297 |
-
return
|
| 4298 |
-
|
| 4299 |
-
result = result_ref[0]
|
| 4300 |
-
if result and result.winner:
|
| 4301 |
-
bracket_md = render_bracket(result)
|
| 4302 |
-
hub_msg = ""
|
| 4303 |
-
if result.hub_repo:
|
| 4304 |
-
hub_msg = f"\nPushed to [{result.hub_repo}](https://huggingface.co/{result.hub_repo})"
|
| 4305 |
-
yield (
|
| 4306 |
-
f"**Champion: `{result.winner.method}`** "
|
| 4307 |
-
f"(score: {result.winner.score:.4f}){hub_msg}",
|
| 4308 |
-
bracket_md,
|
| 4309 |
-
"\n".join(log_lines),
|
| 4310 |
-
)
|
| 4311 |
-
else:
|
| 4312 |
-
yield (
|
| 4313 |
-
"**Tournament complete** — no winner determined.",
|
| 4314 |
-
"",
|
| 4315 |
-
"\n".join(log_lines),
|
| 4316 |
-
)
|
| 4317 |
|
| 4318 |
# ── Tab 7: Export ────────────────────────────────────────────────���
|
| 4319 |
with gr.Tab("Export", id="export"):
|
|
@@ -4336,7 +4491,94 @@ Download all intermediate data from your last obliteration run as a ZIP archive.
|
|
| 4336 |
outputs=[export_file, export_status],
|
| 4337 |
)
|
| 4338 |
|
| 4339 |
-
# ── Tab
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4340 |
with gr.Tab("Leaderboard", id="leaderboard"):
|
| 4341 |
gr.Markdown("""### Community Leaderboard
|
| 4342 |
All benchmark results from **every OBLITERATUS Space** (including duplicated copies) are
|
|
@@ -4562,12 +4804,6 @@ Built on the shoulders of:
|
|
| 4562 |
outputs=[prompt_vol_dd, dataset_info_md],
|
| 4563 |
)
|
| 4564 |
|
| 4565 |
-
# Wire hub repo → live validation
|
| 4566 |
-
hub_repo.change(
|
| 4567 |
-
fn=_validate_hub_repo,
|
| 4568 |
-
inputs=[hub_repo],
|
| 4569 |
-
outputs=[hub_warning_md],
|
| 4570 |
-
)
|
| 4571 |
|
| 4572 |
# Wire benchmark → Chat/A/B cross-tab dropdown updates
|
| 4573 |
bench_btn.click(
|
|
@@ -4616,7 +4852,7 @@ Built on the shoulders of:
|
|
| 4616 |
# may not fire after generator teardown.
|
| 4617 |
obliterate_btn.click(
|
| 4618 |
fn=obliterate,
|
| 4619 |
-
inputs=[model_dd, method_dd,
|
| 4620 |
custom_harmful_tb, custom_harmless_tb] + _adv_controls,
|
| 4621 |
outputs=[status_md, log_box, chat_status, session_model_dd, metrics_md, ab_session_model_dd],
|
| 4622 |
).then(
|
|
|
|
| 57 |
|
| 58 |
import gradio as gr
|
| 59 |
import torch
|
| 60 |
+
from obliteratus import device as dev
|
| 61 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
| 62 |
|
| 63 |
# ── ZeroGPU support ─────────────────────────────────────────────────
|
|
|
|
| 400 |
return ""
|
| 401 |
|
| 402 |
|
| 403 |
+
# ---------------------------------------------------------------------------
|
| 404 |
+
# Push to Hub — dedicated tab backend
|
| 405 |
+
# ---------------------------------------------------------------------------
|
| 406 |
+
|
| 407 |
+
def _generate_model_card(meta: dict) -> str:
|
| 408 |
+
"""Generate a HuggingFace model card README for a session model."""
|
| 409 |
+
model_id = meta.get("model_id", "unknown")
|
| 410 |
+
method = meta.get("method", "unknown")
|
| 411 |
+
source = meta.get("source", "obliterate")
|
| 412 |
+
short_model = model_id.split("/")[-1] if "/" in model_id else model_id
|
| 413 |
+
|
| 414 |
+
metrics_table = ""
|
| 415 |
+
tourney_metrics = meta.get("tourney_metrics")
|
| 416 |
+
if tourney_metrics:
|
| 417 |
+
rows = "\n".join(
|
| 418 |
+
f"| {k.replace('_', ' ').title()} | {v:.4f} |"
|
| 419 |
+
for k, v in tourney_metrics.items() if isinstance(v, (int, float))
|
| 420 |
+
)
|
| 421 |
+
metrics_table = f"\n## Metrics\n\n| Metric | Value |\n|--------|-------|\n{rows}\n"
|
| 422 |
+
|
| 423 |
+
return f"""---
|
| 424 |
+
language: en
|
| 425 |
+
tags:
|
| 426 |
+
- obliteratus
|
| 427 |
+
- abliteration
|
| 428 |
+
- uncensored
|
| 429 |
+
- {source}
|
| 430 |
+
base_model: {model_id}
|
| 431 |
+
---
|
| 432 |
+
|
| 433 |
+
# {short_model}-OBLITERATED
|
| 434 |
+
|
| 435 |
+
This model was abliterated using the **`{method}`** method via
|
| 436 |
+
[OBLITERATUS](https://github.com/elder-plinius/OBLITERATUS).
|
| 437 |
+
|
| 438 |
+
| Detail | Value |
|
| 439 |
+
|--------|-------|
|
| 440 |
+
| Base model | `{model_id}` |
|
| 441 |
+
| Method | `{method}` |
|
| 442 |
+
| Source | {source} |
|
| 443 |
+
{metrics_table}
|
| 444 |
+
## How to Use
|
| 445 |
+
|
| 446 |
+
```python
|
| 447 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 448 |
+
|
| 449 |
+
model = AutoModelForCausalLM.from_pretrained("{short_model}-OBLITERATED")
|
| 450 |
+
tokenizer = AutoTokenizer.from_pretrained("{short_model}-OBLITERATED")
|
| 451 |
+
|
| 452 |
+
prompt = "Hello, how are you?"
|
| 453 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
| 454 |
+
outputs = model.generate(**inputs, max_new_tokens=256)
|
| 455 |
+
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
| 456 |
+
```
|
| 457 |
+
|
| 458 |
+
## About OBLITERATUS
|
| 459 |
+
|
| 460 |
+
OBLITERATUS is an open-source tool for removing refusal behavior from language
|
| 461 |
+
models via activation engineering (abliteration). Learn more at
|
| 462 |
+
[github.com/elder-plinius/OBLITERATUS](https://github.com/elder-plinius/OBLITERATUS).
|
| 463 |
+
"""
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
def _get_hub_session_info(label: str) -> str:
|
| 467 |
+
"""Return a markdown summary of the selected session model."""
|
| 468 |
+
if not label or label.startswith("("):
|
| 469 |
+
return ""
|
| 470 |
+
meta = _session_models.get(label)
|
| 471 |
+
if not meta:
|
| 472 |
+
return "*Session model not found — try refreshing the list.*"
|
| 473 |
+
lines = [
|
| 474 |
+
f"**Model:** `{meta.get('model_id', 'unknown')}`",
|
| 475 |
+
f"**Method:** `{meta.get('method', 'unknown')}`",
|
| 476 |
+
f"**Source:** {meta.get('source', 'unknown')}",
|
| 477 |
+
f"**Path:** `{meta.get('output_dir', 'N/A')}`",
|
| 478 |
+
]
|
| 479 |
+
score = meta.get("tourney_score")
|
| 480 |
+
if score is not None:
|
| 481 |
+
lines.append(f"**Tourney score:** {score:.4f}")
|
| 482 |
+
return "\n".join(lines)
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def _auto_hub_repo_id(label: str) -> str:
|
| 486 |
+
"""Generate an auto-filled Hub repo ID for the selected session model."""
|
| 487 |
+
meta = _session_models.get(label)
|
| 488 |
+
if not meta:
|
| 489 |
+
return ""
|
| 490 |
+
model_id = meta.get("model_id", "")
|
| 491 |
+
import re
|
| 492 |
+
short = model_id.split("/")[-1] if "/" in model_id else model_id
|
| 493 |
+
short = re.sub(r"[^a-zA-Z0-9\-.]", "-", short)
|
| 494 |
+
return f"{_HUB_COMMUNITY_ORG}/{short}-OBLITERATED"
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
def push_session_to_hub(
|
| 498 |
+
session_label: str,
|
| 499 |
+
hub_repo_id: str,
|
| 500 |
+
hub_token_input: str,
|
| 501 |
+
refine_enabled: bool,
|
| 502 |
+
refine_regularization: float,
|
| 503 |
+
refine_passes: int,
|
| 504 |
+
progress=gr.Progress(),
|
| 505 |
+
):
|
| 506 |
+
"""Push a session model to HuggingFace Hub, with optional refinement."""
|
| 507 |
+
import os
|
| 508 |
+
import re
|
| 509 |
+
|
| 510 |
+
if not session_label or session_label.startswith("("):
|
| 511 |
+
yield "**Error:** Select a session model first.", ""
|
| 512 |
+
return
|
| 513 |
+
|
| 514 |
+
meta = _session_models.get(session_label)
|
| 515 |
+
if not meta:
|
| 516 |
+
yield "**Error:** Session model not found. Try refreshing the list.", ""
|
| 517 |
+
return
|
| 518 |
+
|
| 519 |
+
output_dir = meta.get("output_dir", "")
|
| 520 |
+
if not output_dir or not Path(output_dir).exists():
|
| 521 |
+
yield f"**Error:** Model directory not found: `{output_dir}`", ""
|
| 522 |
+
return
|
| 523 |
+
|
| 524 |
+
# Resolve repo ID
|
| 525 |
+
repo_id = hub_repo_id.strip() if hub_repo_id else ""
|
| 526 |
+
if not repo_id:
|
| 527 |
+
repo_id = _auto_hub_repo_id(session_label)
|
| 528 |
+
if not repo_id:
|
| 529 |
+
yield "**Error:** Could not determine Hub repo ID.", ""
|
| 530 |
+
return
|
| 531 |
+
if not re.match(r'^[a-zA-Z0-9_-]+/[a-zA-Z0-9_.-]+$', repo_id):
|
| 532 |
+
yield "**Error:** Invalid repo format. Use `username/model-name`.", ""
|
| 533 |
+
return
|
| 534 |
+
|
| 535 |
+
# Resolve token
|
| 536 |
+
token = hub_token_input.strip() if hub_token_input else None
|
| 537 |
+
if not token:
|
| 538 |
+
token = os.environ.get("HF_TOKEN") or _HUB_COMMUNITY_TOKEN
|
| 539 |
+
if not token:
|
| 540 |
+
yield (
|
| 541 |
+
"**Error:** No Hub token available. Enter a token above, "
|
| 542 |
+
"or set `HF_TOKEN` / `OBLITERATUS_HUB_TOKEN` as an environment variable.",
|
| 543 |
+
"",
|
| 544 |
+
)
|
| 545 |
+
return
|
| 546 |
+
|
| 547 |
+
# Optional refinement pass
|
| 548 |
+
if refine_enabled and refine_passes > 0:
|
| 549 |
+
progress(0.1, desc="Refining model...")
|
| 550 |
+
yield "Applying refinement passes...", ""
|
| 551 |
+
try:
|
| 552 |
+
from obliteratus.abliterate import AbliterationPipeline
|
| 553 |
+
from obliteratus.prompts import load_dataset_source
|
| 554 |
+
|
| 555 |
+
dataset_key = meta.get("dataset_key", "builtin")
|
| 556 |
+
if dataset_key == "custom":
|
| 557 |
+
dataset_key = "builtin"
|
| 558 |
+
harmful, harmless = load_dataset_source(dataset_key)
|
| 559 |
+
n = min(33, len(harmful), len(harmless))
|
| 560 |
+
|
| 561 |
+
pipeline = AbliterationPipeline(
|
| 562 |
+
model_name=output_dir, # load from saved checkpoint
|
| 563 |
+
output_dir=output_dir,
|
| 564 |
+
device="auto",
|
| 565 |
+
dtype="float16",
|
| 566 |
+
method=meta.get("method", "advanced"),
|
| 567 |
+
regularization=refine_regularization,
|
| 568 |
+
refinement_passes=refine_passes,
|
| 569 |
+
harmful_prompts=harmful[:n],
|
| 570 |
+
harmless_prompts=harmless[:n],
|
| 571 |
+
)
|
| 572 |
+
pipeline.run()
|
| 573 |
+
except Exception as e:
|
| 574 |
+
yield f"**Refinement failed:** {e}", ""
|
| 575 |
+
return
|
| 576 |
+
|
| 577 |
+
# Generate model card
|
| 578 |
+
progress(0.5, desc="Generating model card...")
|
| 579 |
+
yield f"Generating model card and uploading to `{repo_id}`...", ""
|
| 580 |
+
card_content = _generate_model_card(meta)
|
| 581 |
+
card_path = Path(output_dir) / "README.md"
|
| 582 |
+
card_path.write_text(card_content)
|
| 583 |
+
|
| 584 |
+
# Upload to Hub
|
| 585 |
+
progress(0.6, desc="Uploading to Hub...")
|
| 586 |
+
try:
|
| 587 |
+
from huggingface_hub import HfApi
|
| 588 |
+
api = HfApi(token=token)
|
| 589 |
+
api.create_repo(repo_id, exist_ok=True)
|
| 590 |
+
|
| 591 |
+
method = meta.get("method", "unknown")
|
| 592 |
+
model_id = meta.get("model_id", "unknown")
|
| 593 |
+
api.upload_folder(
|
| 594 |
+
folder_path=output_dir,
|
| 595 |
+
repo_id=repo_id,
|
| 596 |
+
commit_message=f"OBLITERATUS: {method} on {model_id}",
|
| 597 |
+
)
|
| 598 |
+
except Exception as e:
|
| 599 |
+
yield f"**Upload failed:** {e}", ""
|
| 600 |
+
return
|
| 601 |
+
|
| 602 |
+
progress(1.0, desc="Done!")
|
| 603 |
+
hub_url = f"https://huggingface.co/{repo_id}"
|
| 604 |
+
yield (
|
| 605 |
+
f"**Pushed successfully to [{repo_id}]({hub_url})**",
|
| 606 |
+
f"[Open on HuggingFace Hub]({hub_url})",
|
| 607 |
+
)
|
| 608 |
+
|
| 609 |
+
|
| 610 |
PROMPT_VOLUMES = {
|
| 611 |
"33 (fast)": 33,
|
| 612 |
"66 (better signal)": 66,
|
|
|
|
| 655 |
# ---------------------------------------------------------------------------
|
| 656 |
|
| 657 |
def _clear_gpu():
|
| 658 |
+
"""Free GPU/accelerator memory. Resilient to device errors."""
|
| 659 |
with _lock:
|
| 660 |
_state["model"] = None
|
| 661 |
_state["tokenizer"] = None
|
| 662 |
+
dev.free_gpu_memory()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 663 |
|
| 664 |
|
| 665 |
def _install_steering_hooks(model, steering_meta: dict) -> int:
|
|
|
|
| 783 |
# ---------------------------------------------------------------------------
|
| 784 |
|
| 785 |
def _get_vram_html() -> str:
|
| 786 |
+
"""Return an HTML snippet showing GPU/accelerator memory usage as a styled bar."""
|
| 787 |
+
if not dev.is_gpu_available():
|
| 788 |
return (
|
| 789 |
'<div style="text-align:center;color:#4a5568;font-size:0.72rem;'
|
| 790 |
'letter-spacing:1px;margin-top:6px;">CPU ONLY — NO GPU DETECTED</div>'
|
| 791 |
)
|
| 792 |
try:
|
| 793 |
+
mem = dev.get_memory_info()
|
| 794 |
+
used = mem.used_gb
|
| 795 |
+
total = mem.total_gb
|
| 796 |
pct = (used / total * 100) if total > 0 else 0
|
| 797 |
# Color shifts from green → yellow → red
|
| 798 |
if pct < 50:
|
|
|
|
| 801 |
bar_color = "#ffcc00"
|
| 802 |
else:
|
| 803 |
bar_color = "#ff003c"
|
| 804 |
+
device_name = mem.device_name
|
| 805 |
+
reserved_html = (
|
| 806 |
+
f'<span style="color:#4a5568;">reserved: {mem.reserved_gb:.1f} GB</span>'
|
| 807 |
+
if mem.reserved_gb > 0
|
| 808 |
+
else f'<span style="color:#4a5568;">unified memory</span>'
|
| 809 |
+
)
|
| 810 |
return (
|
| 811 |
f'<div style="margin:6px auto 0;max-width:480px;">'
|
| 812 |
f'<div style="display:flex;justify-content:space-between;font-size:0.68rem;'
|
| 813 |
f'color:#4a5568;letter-spacing:1px;margin-bottom:2px;">'
|
| 814 |
+
f'<span>{device_name}</span>'
|
| 815 |
f'<span>{used:.1f} / {total:.1f} GB ({pct:.0f}%)</span></div>'
|
| 816 |
f'<div style="background:#0a0a0f;border:1px solid #1a1f2e;border-radius:3px;'
|
| 817 |
f'height:10px;overflow:hidden;">'
|
|
|
|
| 819 |
f'box-shadow:0 0 6px {bar_color};transition:width 0.5s ease;"></div></div>'
|
| 820 |
f'<div style="display:flex;justify-content:space-between;font-size:0.6rem;'
|
| 821 |
f'color:#333;margin-top:1px;">'
|
| 822 |
+
f'{reserved_html}</div>'
|
| 823 |
f'</div>'
|
| 824 |
)
|
| 825 |
except Exception:
|
| 826 |
+
return '<div style="text-align:center;color:#4a5568;font-size:0.72rem;">Memory: unavailable</div>'
|
| 827 |
|
| 828 |
|
| 829 |
# ---------------------------------------------------------------------------
|
|
|
|
| 1266 |
pass
|
| 1267 |
pipeline_ref[0] = None
|
| 1268 |
gc.collect()
|
| 1269 |
+
dev.empty_cache()
|
|
|
|
| 1270 |
|
| 1271 |
yield (
|
| 1272 |
f"**{method_key} complete** ({mi + 1}/{len(methods_to_test)}) \u2014 {_bench_elapsed()}",
|
|
|
|
| 1616 |
pass
|
| 1617 |
pipeline_ref[0] = None
|
| 1618 |
gc.collect()
|
| 1619 |
+
dev.empty_cache()
|
|
|
|
| 1620 |
|
| 1621 |
yield (
|
| 1622 |
f"**{model_id} complete** ({mi + 1}/{len(model_choices)}) \u2014 {_mm_elapsed()}",
|
|
|
|
| 1715 |
|
| 1716 |
@spaces.GPU(duration=300)
|
| 1717 |
def obliterate(model_choice: str, method_choice: str,
|
|
|
|
| 1718 |
prompt_volume_choice: str, dataset_source_choice: str,
|
| 1719 |
custom_harmful: str, custom_harmless: str,
|
| 1720 |
# Advanced params (sliders)
|
|
|
|
| 1747 |
model_id = MODELS.get(model_choice, model_choice)
|
| 1748 |
is_preset = model_choice in MODELS
|
| 1749 |
method = METHODS.get(method_choice, "advanced")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1750 |
prompt_volume = PROMPT_VOLUMES.get(prompt_volume_choice, 33)
|
| 1751 |
|
| 1752 |
# Resolve "adaptive" → telemetry-recommended method for this model
|
|
|
|
| 1794 |
)
|
| 1795 |
return
|
| 1796 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1797 |
# Resolve dataset source — custom prompts override the dropdown
|
| 1798 |
use_custom = custom_harmful and custom_harmful.strip()
|
| 1799 |
dataset_key = get_source_key_from_label(dataset_source_choice) if dataset_source_choice else "builtin"
|
|
|
|
| 1867 |
output_dir=save_dir,
|
| 1868 |
device="auto",
|
| 1869 |
dtype="float16",
|
|
|
|
|
|
|
|
|
|
| 1870 |
quantization=quantization,
|
| 1871 |
trust_remote_code=is_preset,
|
| 1872 |
harmful_prompts=harmful_all[:n],
|
|
|
|
| 1884 |
device="auto",
|
| 1885 |
dtype="float16",
|
| 1886 |
method=method,
|
|
|
|
|
|
|
|
|
|
| 1887 |
quantization=quantization,
|
| 1888 |
trust_remote_code=is_preset,
|
| 1889 |
harmful_prompts=harmful_all[:n],
|
|
|
|
| 1936 |
log_lines.append(f"Dataset: {source_label}")
|
| 1937 |
vol_label = "all" if prompt_volume == -1 else str(prompt_volume)
|
| 1938 |
log_lines.append(f"Prompt volume: {vol_label} pairs")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1939 |
if quantization:
|
| 1940 |
log_lines.append(f"Quantization: {quantization} (auto-detected for GPU fit)")
|
| 1941 |
log_lines.append("")
|
|
|
|
| 2274 |
_needs_reload = model is None or tokenizer is None
|
| 2275 |
if not _needs_reload:
|
| 2276 |
try:
|
| 2277 |
+
model_dev = next(model.parameters()).device
|
| 2278 |
+
if model_dev.type == "meta":
|
| 2279 |
_needs_reload = True
|
| 2280 |
+
elif dev.is_gpu_available() and model_dev.type not in ("cuda", "mps"):
|
| 2281 |
+
model.to(dev.get_device())
|
| 2282 |
except Exception:
|
| 2283 |
_needs_reload = True
|
| 2284 |
|
|
|
|
| 2708 |
_needs_reload = abliterated_model is None or tokenizer is None
|
| 2709 |
if not _needs_reload:
|
| 2710 |
try:
|
| 2711 |
+
model_dev = next(abliterated_model.parameters()).device
|
| 2712 |
+
if model_dev.type == "meta":
|
| 2713 |
_needs_reload = True
|
| 2714 |
+
elif dev.is_gpu_available() and model_dev.type not in ("cuda", "mps"):
|
| 2715 |
+
abliterated_model.to(dev.get_device())
|
| 2716 |
except Exception:
|
| 2717 |
_needs_reload = True
|
| 2718 |
|
|
|
|
| 2845 |
abl_device = next(abliterated_model.parameters()).device
|
| 2846 |
abliterated_model.to("cpu")
|
| 2847 |
gc.collect()
|
| 2848 |
+
dev.empty_cache()
|
|
|
|
| 2849 |
|
| 2850 |
model_id = MODELS.get(model_name, model_name)
|
| 2851 |
# Only trust remote code for known preset models, not arbitrary user-supplied IDs
|
|
|
|
| 2897 |
# Free the original model
|
| 2898 |
del original_model
|
| 2899 |
gc.collect()
|
| 2900 |
+
dev.empty_cache()
|
|
|
|
| 2901 |
|
| 2902 |
except Exception as e:
|
| 2903 |
original_response = f"*Could not load original model for comparison: {e}*"
|
|
|
|
| 2906 |
# Use torch.device("cuda") rather than the captured abl_device, since
|
| 2907 |
# on ZeroGPU the original device reference may point to a stale context.
|
| 2908 |
try:
|
| 2909 |
+
restore_device = torch.device(dev.get_device()) if dev.is_gpu_available() else abl_device
|
| 2910 |
abliterated_model.to(restore_device)
|
| 2911 |
except Exception:
|
| 2912 |
pass # If GPU restore fails, model stays on CPU (still usable)
|
|
|
|
| 3024 |
|
| 3025 |
# Cleanup between runs
|
| 3026 |
gc.collect()
|
| 3027 |
+
dev.empty_cache()
|
|
|
|
| 3028 |
|
| 3029 |
# Generate dose-response curve
|
| 3030 |
gallery = None
|
|
|
|
| 3116 |
return "\n".join(lines)
|
| 3117 |
|
| 3118 |
|
| 3119 |
+
# ---------------------------------------------------------------------------
|
| 3120 |
+
# Tournament
|
| 3121 |
+
# ---------------------------------------------------------------------------
|
| 3122 |
+
|
| 3123 |
+
@spaces.GPU(duration=300)
|
| 3124 |
+
def run_tourney(model_choice, dataset, quantization):
|
| 3125 |
+
"""Run an elimination tournament across all abliteration methods.
|
| 3126 |
+
|
| 3127 |
+
On ZeroGPU Spaces the @spaces.GPU decorator allocates a GPU for the
|
| 3128 |
+
duration of the tournament (up to 5 minutes).
|
| 3129 |
+
"""
|
| 3130 |
+
if not model_choice or not model_choice.strip():
|
| 3131 |
+
yield "**Error:** Select a model first.", "", ""
|
| 3132 |
+
return
|
| 3133 |
+
|
| 3134 |
+
from obliteratus.tourney import TourneyRunner, render_bracket
|
| 3135 |
+
|
| 3136 |
+
# Resolve display label → HuggingFace model ID
|
| 3137 |
+
model_id = model_choice.strip()
|
| 3138 |
+
if model_id in MODELS:
|
| 3139 |
+
model_id = MODELS[model_id]
|
| 3140 |
+
|
| 3141 |
+
quant = quantization if quantization != "none" else None
|
| 3142 |
+
|
| 3143 |
+
log_lines = []
|
| 3144 |
+
|
| 3145 |
+
def on_log(msg):
|
| 3146 |
+
log_lines.append(msg)
|
| 3147 |
+
|
| 3148 |
+
def on_round(rnd):
|
| 3149 |
+
pass # logged via on_log
|
| 3150 |
+
|
| 3151 |
+
dataset_key = get_source_key_from_label(dataset) if dataset else "builtin"
|
| 3152 |
+
|
| 3153 |
+
runner = TourneyRunner(
|
| 3154 |
+
model_name=model_id,
|
| 3155 |
+
hub_org=None,
|
| 3156 |
+
hub_repo=None,
|
| 3157 |
+
dataset_key=dataset_key,
|
| 3158 |
+
quantization=quant,
|
| 3159 |
+
on_log=on_log,
|
| 3160 |
+
on_round=on_round,
|
| 3161 |
+
)
|
| 3162 |
+
|
| 3163 |
+
# Run tournament in a background thread so we can yield progress
|
| 3164 |
+
import threading
|
| 3165 |
+
result_ref = [None]
|
| 3166 |
+
error_ref = [None]
|
| 3167 |
+
|
| 3168 |
+
def _run():
|
| 3169 |
+
try:
|
| 3170 |
+
result_ref[0] = runner.run()
|
| 3171 |
+
except Exception as e:
|
| 3172 |
+
error_ref[0] = e
|
| 3173 |
+
|
| 3174 |
+
thread = threading.Thread(target=_run, daemon=True)
|
| 3175 |
+
thread.start()
|
| 3176 |
+
|
| 3177 |
+
while thread.is_alive():
|
| 3178 |
+
yield (
|
| 3179 |
+
"**Tournament in progress...**",
|
| 3180 |
+
"",
|
| 3181 |
+
"\n".join(log_lines[-100:]),
|
| 3182 |
+
)
|
| 3183 |
+
time.sleep(1.0)
|
| 3184 |
+
|
| 3185 |
+
thread.join(timeout=10)
|
| 3186 |
+
|
| 3187 |
+
if error_ref[0]:
|
| 3188 |
+
yield (
|
| 3189 |
+
f"**Error:** {error_ref[0]}",
|
| 3190 |
+
"",
|
| 3191 |
+
"\n".join(log_lines),
|
| 3192 |
+
)
|
| 3193 |
+
return
|
| 3194 |
+
|
| 3195 |
+
result = result_ref[0]
|
| 3196 |
+
if result and result.winner:
|
| 3197 |
+
bracket_md = render_bracket(result)
|
| 3198 |
+
# Register winner in session models for Push to Hub tab
|
| 3199 |
+
if result.winner.output_dir:
|
| 3200 |
+
_ts = datetime.now().strftime("%H:%M")
|
| 3201 |
+
_short = model_id.split("/")[-1] if "/" in model_id else model_id
|
| 3202 |
+
_label = f"tourney winner ({result.winner.method}) on {_short} ({_ts})"
|
| 3203 |
+
with _lock:
|
| 3204 |
+
_session_models[_label] = {
|
| 3205 |
+
"model_id": model_id,
|
| 3206 |
+
"model_choice": model_choice,
|
| 3207 |
+
"method": result.winner.method,
|
| 3208 |
+
"dataset_key": dataset_key,
|
| 3209 |
+
"prompt_volume": 0,
|
| 3210 |
+
"output_dir": result.winner.output_dir,
|
| 3211 |
+
"source": "tourney",
|
| 3212 |
+
"tourney_score": result.winner.score,
|
| 3213 |
+
"tourney_metrics": result.winner.metrics,
|
| 3214 |
+
}
|
| 3215 |
+
yield (
|
| 3216 |
+
f"**Champion: `{result.winner.method}`** "
|
| 3217 |
+
f"(score: {result.winner.score:.4f})\n"
|
| 3218 |
+
f"Push it to HuggingFace Hub from the **Push to Hub** tab.",
|
| 3219 |
+
bracket_md,
|
| 3220 |
+
"\n".join(log_lines),
|
| 3221 |
+
)
|
| 3222 |
+
else:
|
| 3223 |
+
yield (
|
| 3224 |
+
"**Tournament complete** — no winner determined.",
|
| 3225 |
+
"",
|
| 3226 |
+
"\n".join(log_lines),
|
| 3227 |
+
)
|
| 3228 |
+
|
| 3229 |
+
|
| 3230 |
# ---------------------------------------------------------------------------
|
| 3231 |
# Export Research Artifacts
|
| 3232 |
# ---------------------------------------------------------------------------
|
|
|
|
| 3787 |
lines=5,
|
| 3788 |
)
|
| 3789 |
|
| 3790 |
+
gr.Markdown(
|
| 3791 |
+
"*After obliterating, push your model to HuggingFace Hub from the **Push to Hub** tab.*",
|
| 3792 |
+
elem_classes=["hub-hint"],
|
| 3793 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3794 |
|
| 3795 |
# ── Advanced Settings (auto-populated from method preset) ────
|
| 3796 |
_defaults = _get_preset_defaults("advanced (recommended)")
|
|
|
|
| 4422 |
with gr.Tab("Tourney", id="tourney"):
|
| 4423 |
gr.Markdown("""### March Madness Tournament
|
| 4424 |
Pit **all abliteration methods** against each other in elimination rounds.
|
| 4425 |
+
The winner is saved locally — push it to HuggingFace Hub from the **Push to Hub** tab.
|
| 4426 |
|
| 4427 |
**Round 1 — Qualifiers:** All methods, reduced prompts. Bottom half eliminated.
|
| 4428 |
**Round 2 — Semifinals:** Survivors, full prompts. Bottom half eliminated.
|
| 4429 |
**Round 3 — Finals:** Top contenders, maximum prompts. Champion crowned.
|
| 4430 |
""")
|
| 4431 |
+
tourney_model_dd = gr.Dropdown(
|
| 4432 |
+
choices=list(MODELS.keys()),
|
| 4433 |
+
value="Alibaba (Qwen) / Qwen3-4B",
|
| 4434 |
+
label="Target Model",
|
| 4435 |
+
info="Select a model to tournament-abliterate",
|
| 4436 |
+
allow_custom_value=True,
|
| 4437 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4438 |
|
| 4439 |
with gr.Accordion("Advanced Settings", open=False):
|
| 4440 |
with gr.Row():
|
|
|
|
| 4463 |
interactive=False,
|
| 4464 |
)
|
| 4465 |
|
| 4466 |
+
tourney_btn.click(
|
| 4467 |
+
fn=run_tourney,
|
| 4468 |
+
inputs=[tourney_model_dd,
|
| 4469 |
tourney_dataset_dd, tourney_quant_dd],
|
| 4470 |
outputs=[tourney_status, tourney_bracket, tourney_log],
|
| 4471 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4472 |
|
| 4473 |
# ── Tab 7: Export ────────────────────────────────────────────────���
|
| 4474 |
with gr.Tab("Export", id="export"):
|
|
|
|
| 4491 |
outputs=[export_file, export_status],
|
| 4492 |
)
|
| 4493 |
|
| 4494 |
+
# ── Tab: Push to Hub ──────────────────────────────────────────────
|
| 4495 |
+
with gr.Tab("Push to Hub", id="push_hub"):
|
| 4496 |
+
gr.Markdown("""### Push to HuggingFace Hub
|
| 4497 |
+
Select any session model from your Obliterate, Benchmark, or Tourney runs,
|
| 4498 |
+
optionally apply a quick refinement pass, then push to HuggingFace Hub
|
| 4499 |
+
with the **-OBLITERATED** tag.
|
| 4500 |
+
""")
|
| 4501 |
+
|
| 4502 |
+
with gr.Row():
|
| 4503 |
+
with gr.Column(scale=2):
|
| 4504 |
+
push_session_dd = gr.Dropdown(
|
| 4505 |
+
choices=_get_session_model_choices(),
|
| 4506 |
+
label="Session Model",
|
| 4507 |
+
info="Pick a model from any tab's output",
|
| 4508 |
+
)
|
| 4509 |
+
push_refresh_btn = gr.Button("Refresh List", variant="secondary", size="sm")
|
| 4510 |
+
push_model_info = gr.Markdown("")
|
| 4511 |
+
|
| 4512 |
+
with gr.Column(scale=1):
|
| 4513 |
+
push_repo_id = gr.Textbox(
|
| 4514 |
+
label="Hub Repo ID",
|
| 4515 |
+
placeholder="auto-filled, or type your own",
|
| 4516 |
+
info="e.g. my-org/my-model-OBLITERATED",
|
| 4517 |
+
)
|
| 4518 |
+
push_token = gr.Textbox(
|
| 4519 |
+
label="HF Token (optional)",
|
| 4520 |
+
placeholder="hf_...",
|
| 4521 |
+
type="password",
|
| 4522 |
+
info="Leave blank to use HF_TOKEN env var or community token",
|
| 4523 |
+
)
|
| 4524 |
+
push_repo_warning = gr.Markdown("")
|
| 4525 |
+
|
| 4526 |
+
with gr.Accordion("Quick Refiner (optional)", open=False):
|
| 4527 |
+
gr.Markdown(
|
| 4528 |
+
"*Optionally apply extra refinement passes to your model before pushing. "
|
| 4529 |
+
"This re-runs the abliteration pipeline with adjusted regularization.*"
|
| 4530 |
+
)
|
| 4531 |
+
with gr.Row():
|
| 4532 |
+
push_refine_reg = gr.Slider(
|
| 4533 |
+
0.0, 1.0, value=0.1, step=0.05,
|
| 4534 |
+
label="Regularization",
|
| 4535 |
+
info="Weight preservation (0 = full removal, 1 = no change)",
|
| 4536 |
+
)
|
| 4537 |
+
push_refine_passes = gr.Slider(
|
| 4538 |
+
0, 3, value=0, step=1,
|
| 4539 |
+
label="Extra Refinement Passes",
|
| 4540 |
+
info="0 = skip refinement, 1-3 = apply additional passes",
|
| 4541 |
+
)
|
| 4542 |
+
push_refine_enabled = gr.Checkbox(
|
| 4543 |
+
label="Apply refinement before pushing",
|
| 4544 |
+
value=False,
|
| 4545 |
+
)
|
| 4546 |
+
|
| 4547 |
+
push_btn = gr.Button(
|
| 4548 |
+
"Push to Hub",
|
| 4549 |
+
variant="primary",
|
| 4550 |
+
size="lg",
|
| 4551 |
+
)
|
| 4552 |
+
push_status = gr.Markdown("")
|
| 4553 |
+
push_link = gr.Markdown("")
|
| 4554 |
+
|
| 4555 |
+
# -- Event wiring (inline since components are scoped to this tab) --
|
| 4556 |
+
|
| 4557 |
+
push_refresh_btn.click(
|
| 4558 |
+
fn=lambda: gr.update(choices=_get_session_model_choices()),
|
| 4559 |
+
outputs=[push_session_dd],
|
| 4560 |
+
)
|
| 4561 |
+
|
| 4562 |
+
push_session_dd.change(
|
| 4563 |
+
fn=lambda label: (_get_hub_session_info(label), _auto_hub_repo_id(label)),
|
| 4564 |
+
inputs=[push_session_dd],
|
| 4565 |
+
outputs=[push_model_info, push_repo_id],
|
| 4566 |
+
)
|
| 4567 |
+
|
| 4568 |
+
push_repo_id.change(
|
| 4569 |
+
fn=_validate_hub_repo,
|
| 4570 |
+
inputs=[push_repo_id],
|
| 4571 |
+
outputs=[push_repo_warning],
|
| 4572 |
+
)
|
| 4573 |
+
|
| 4574 |
+
push_btn.click(
|
| 4575 |
+
fn=push_session_to_hub,
|
| 4576 |
+
inputs=[push_session_dd, push_repo_id, push_token,
|
| 4577 |
+
push_refine_enabled, push_refine_reg, push_refine_passes],
|
| 4578 |
+
outputs=[push_status, push_link],
|
| 4579 |
+
)
|
| 4580 |
+
|
| 4581 |
+
# ── Tab: Leaderboard ────────────────────────────────────────────
|
| 4582 |
with gr.Tab("Leaderboard", id="leaderboard"):
|
| 4583 |
gr.Markdown("""### Community Leaderboard
|
| 4584 |
All benchmark results from **every OBLITERATUS Space** (including duplicated copies) are
|
|
|
|
| 4804 |
outputs=[prompt_vol_dd, dataset_info_md],
|
| 4805 |
)
|
| 4806 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4807 |
|
| 4808 |
# Wire benchmark → Chat/A/B cross-tab dropdown updates
|
| 4809 |
bench_btn.click(
|
|
|
|
| 4852 |
# may not fire after generator teardown.
|
| 4853 |
obliterate_btn.click(
|
| 4854 |
fn=obliterate,
|
| 4855 |
+
inputs=[model_dd, method_dd, prompt_vol_dd, dataset_dd,
|
| 4856 |
custom_harmful_tb, custom_harmless_tb] + _adv_controls,
|
| 4857 |
outputs=[status_md, log_box, chat_status, session_model_dd, metrics_md, ab_session_model_dd],
|
| 4858 |
).then(
|
obliteratus/.DS_Store
CHANGED
|
Binary files a/obliteratus/.DS_Store and b/obliteratus/.DS_Store differ
|
|
|
obliteratus/abliterate.py
CHANGED
|
@@ -33,11 +33,12 @@ from typing import Any, Callable
|
|
| 33 |
import torch
|
| 34 |
import torch.nn as nn
|
| 35 |
|
|
|
|
|
|
|
| 36 |
# Reduce CUDA memory fragmentation for large models. Must be set before any
|
| 37 |
# CUDA allocations, so we do it at import time. This is the PyTorch-recommended
|
| 38 |
# fix for "reserved but unallocated" memory issues.
|
| 39 |
-
|
| 40 |
-
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 41 |
|
| 42 |
from obliteratus.models.loader import ModelHandle, load_model # noqa: E402
|
| 43 |
from obliteratus.strategies.utils import ( # noqa: E402
|
|
@@ -788,16 +789,8 @@ class AbliterationPipeline:
|
|
| 788 |
|
| 789 |
@staticmethod
|
| 790 |
def _free_gpu_memory():
|
| 791 |
-
"""Release unused GPU memory between pipeline stages."""
|
| 792 |
-
|
| 793 |
-
gc.collect()
|
| 794 |
-
if torch.cuda.is_available():
|
| 795 |
-
try:
|
| 796 |
-
torch.cuda.empty_cache()
|
| 797 |
-
except Exception:
|
| 798 |
-
# CUDA may be in an error state after illegal memory access;
|
| 799 |
-
# swallow so we don't cascade into every subsequent stage.
|
| 800 |
-
pass
|
| 801 |
|
| 802 |
@staticmethod
|
| 803 |
def _get_model_device(model: nn.Module) -> torch.device:
|
|
@@ -1404,12 +1397,8 @@ class AbliterationPipeline:
|
|
| 1404 |
max_length = self.max_seq_length
|
| 1405 |
else:
|
| 1406 |
max_length = 384 if collect_multi_pos else 256
|
| 1407 |
-
free_gb =
|
| 1408 |
-
if
|
| 1409 |
-
free_gb = sum(
|
| 1410 |
-
torch.cuda.mem_get_info(i)[0] / (1024 ** 3)
|
| 1411 |
-
for i in range(torch.cuda.device_count())
|
| 1412 |
-
)
|
| 1413 |
if self.max_seq_length is None and free_gb < 2.0:
|
| 1414 |
max_length = 64
|
| 1415 |
self.log(f" Low GPU memory ({free_gb:.1f} GB free), using max_length={max_length}")
|
|
@@ -1999,22 +1988,22 @@ class AbliterationPipeline:
|
|
| 1999 |
# Memory-aware cap: SAE encoder+decoder use
|
| 2000 |
# 2 * hidden * (expansion * hidden) * 4 bytes
|
| 2001 |
sae_mem_mb = 2 * hidden_dim * (sae_expansion * hidden_dim) * 4 / 1e6
|
| 2002 |
-
if
|
| 2003 |
try:
|
| 2004 |
-
free_mb =
|
| 2005 |
# Leave 512 MB headroom for other ops
|
| 2006 |
while sae_mem_mb > (free_mb - 512) and sae_expansion > 1:
|
| 2007 |
sae_expansion //= 2
|
| 2008 |
sae_mem_mb = 2 * hidden_dim * (sae_expansion * hidden_dim) * 4 / 1e6
|
| 2009 |
except Exception:
|
| 2010 |
pass # Fallback to hidden_dim-based heuristic
|
| 2011 |
-
# Use GPU when enough headroom exists (SAE is small relative to model)
|
| 2012 |
sae_device = "cpu"
|
| 2013 |
-
if
|
| 2014 |
try:
|
| 2015 |
-
sae_free_mb =
|
| 2016 |
if sae_free_mb > sae_mem_mb + 1024:
|
| 2017 |
-
sae_device =
|
| 2018 |
except Exception:
|
| 2019 |
pass
|
| 2020 |
sae = train_sae(
|
|
@@ -4155,7 +4144,8 @@ class AbliterationPipeline:
|
|
| 4155 |
continue
|
| 4156 |
original_norm = saved_norms[param_name]
|
| 4157 |
if original_norm > 0:
|
| 4158 |
-
|
|
|
|
| 4159 |
new_norm = data.norm().item()
|
| 4160 |
if math.isnan(new_norm) or math.isinf(new_norm) or new_norm == 0:
|
| 4161 |
continue # Skip — weight is degenerate after projection
|
|
@@ -4165,7 +4155,12 @@ class AbliterationPipeline:
|
|
| 4165 |
# layers. Uncapped amplification destroys coherence.
|
| 4166 |
if ratio > _MAX_NORM_RATIO:
|
| 4167 |
ratio = _MAX_NORM_RATIO
|
| 4168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4169 |
|
| 4170 |
@staticmethod
|
| 4171 |
def _project_out_advanced(
|
|
@@ -5371,16 +5366,19 @@ class AbliterationPipeline:
|
|
| 5371 |
unique_ratio = len(set(words)) / len(words)
|
| 5372 |
if unique_ratio > 0.2:
|
| 5373 |
coherent_count += 1
|
| 5374 |
-
except torch.cuda.OutOfMemoryError:
|
| 5375 |
-
self._free_gpu_memory()
|
| 5376 |
-
self.log(" Skipping generation tests (CUDA out of memory — model too large for KV cache)")
|
| 5377 |
-
generation_failed = True
|
| 5378 |
except (RuntimeError, Exception) as e:
|
| 5379 |
-
|
| 5380 |
-
if "CUDA" in err_msg or "illegal" in err_msg.lower():
|
| 5381 |
self._free_gpu_memory()
|
| 5382 |
-
self.log(
|
| 5383 |
generation_failed = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5384 |
else:
|
| 5385 |
raise
|
| 5386 |
|
|
@@ -5523,18 +5521,21 @@ class AbliterationPipeline:
|
|
| 5523 |
|
| 5524 |
del inputs, outputs
|
| 5525 |
self._free_gpu_memory()
|
| 5526 |
-
except torch.cuda.OutOfMemoryError:
|
| 5527 |
-
self._free_gpu_memory()
|
| 5528 |
-
self.log(f" [batch {batch_start+1}-{batch_end}] CUDA OOM — stopping")
|
| 5529 |
-
self.log(" Skipping remaining refusal tests (CUDA out of memory)")
|
| 5530 |
-
oom_break = True
|
| 5531 |
except (RuntimeError, Exception) as e:
|
| 5532 |
-
|
| 5533 |
-
if "CUDA" in err_msg or "illegal" in err_msg.lower():
|
| 5534 |
self._free_gpu_memory()
|
| 5535 |
-
self.log(f" [batch {batch_start+1}-{batch_end}]
|
| 5536 |
-
self.log(
|
| 5537 |
oom_break = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5538 |
else:
|
| 5539 |
raise
|
| 5540 |
|
|
|
|
| 33 |
import torch
|
| 34 |
import torch.nn as nn
|
| 35 |
|
| 36 |
+
from obliteratus import device as dev # noqa: E402 — must import before CUDA setup
|
| 37 |
+
|
| 38 |
# Reduce CUDA memory fragmentation for large models. Must be set before any
|
| 39 |
# CUDA allocations, so we do it at import time. This is the PyTorch-recommended
|
| 40 |
# fix for "reserved but unallocated" memory issues.
|
| 41 |
+
dev.configure_cuda_alloc()
|
|
|
|
| 42 |
|
| 43 |
from obliteratus.models.loader import ModelHandle, load_model # noqa: E402
|
| 44 |
from obliteratus.strategies.utils import ( # noqa: E402
|
|
|
|
| 789 |
|
| 790 |
@staticmethod
|
| 791 |
def _free_gpu_memory():
|
| 792 |
+
"""Release unused GPU/accelerator memory between pipeline stages."""
|
| 793 |
+
dev.free_gpu_memory()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 794 |
|
| 795 |
@staticmethod
|
| 796 |
def _get_model_device(model: nn.Module) -> torch.device:
|
|
|
|
| 1397 |
max_length = self.max_seq_length
|
| 1398 |
else:
|
| 1399 |
max_length = 384 if collect_multi_pos else 256
|
| 1400 |
+
free_gb = dev.get_total_free_gb()
|
| 1401 |
+
if dev.is_gpu_available():
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1402 |
if self.max_seq_length is None and free_gb < 2.0:
|
| 1403 |
max_length = 64
|
| 1404 |
self.log(f" Low GPU memory ({free_gb:.1f} GB free), using max_length={max_length}")
|
|
|
|
| 1988 |
# Memory-aware cap: SAE encoder+decoder use
|
| 1989 |
# 2 * hidden * (expansion * hidden) * 4 bytes
|
| 1990 |
sae_mem_mb = 2 * hidden_dim * (sae_expansion * hidden_dim) * 4 / 1e6
|
| 1991 |
+
if dev.is_gpu_available():
|
| 1992 |
try:
|
| 1993 |
+
free_mb = dev.get_total_free_gb() * 1024
|
| 1994 |
# Leave 512 MB headroom for other ops
|
| 1995 |
while sae_mem_mb > (free_mb - 512) and sae_expansion > 1:
|
| 1996 |
sae_expansion //= 2
|
| 1997 |
sae_mem_mb = 2 * hidden_dim * (sae_expansion * hidden_dim) * 4 / 1e6
|
| 1998 |
except Exception:
|
| 1999 |
pass # Fallback to hidden_dim-based heuristic
|
| 2000 |
+
# Use GPU/MPS when enough headroom exists (SAE is small relative to model)
|
| 2001 |
sae_device = "cpu"
|
| 2002 |
+
if dev.is_gpu_available():
|
| 2003 |
try:
|
| 2004 |
+
sae_free_mb = dev.get_total_free_gb() * 1024
|
| 2005 |
if sae_free_mb > sae_mem_mb + 1024:
|
| 2006 |
+
sae_device = dev.get_device()
|
| 2007 |
except Exception:
|
| 2008 |
pass
|
| 2009 |
sae = train_sae(
|
|
|
|
| 4144 |
continue
|
| 4145 |
original_norm = saved_norms[param_name]
|
| 4146 |
if original_norm > 0:
|
| 4147 |
+
needs_cast = not param.data.is_floating_point()
|
| 4148 |
+
data = param.data.float() if needs_cast else param.data
|
| 4149 |
new_norm = data.norm().item()
|
| 4150 |
if math.isnan(new_norm) or math.isinf(new_norm) or new_norm == 0:
|
| 4151 |
continue # Skip — weight is degenerate after projection
|
|
|
|
| 4155 |
# layers. Uncapped amplification destroys coherence.
|
| 4156 |
if ratio > _MAX_NORM_RATIO:
|
| 4157 |
ratio = _MAX_NORM_RATIO
|
| 4158 |
+
if needs_cast:
|
| 4159 |
+
# Non-float dtypes (e.g. uint8) can't mul_ by a float
|
| 4160 |
+
# scalar in-place — rescale in float then cast back.
|
| 4161 |
+
param.data.copy_(data.mul_(ratio).to(param.data.dtype))
|
| 4162 |
+
else:
|
| 4163 |
+
param.data.mul_(ratio)
|
| 4164 |
|
| 4165 |
@staticmethod
|
| 4166 |
def _project_out_advanced(
|
|
|
|
| 5366 |
unique_ratio = len(set(words)) / len(words)
|
| 5367 |
if unique_ratio > 0.2:
|
| 5368 |
coherent_count += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5369 |
except (RuntimeError, Exception) as e:
|
| 5370 |
+
if dev.is_oom_error(e):
|
|
|
|
| 5371 |
self._free_gpu_memory()
|
| 5372 |
+
self.log(" Skipping generation tests (out of memory — model too large for KV cache)")
|
| 5373 |
generation_failed = True
|
| 5374 |
+
elif isinstance(e, RuntimeError):
|
| 5375 |
+
err_msg = str(e)
|
| 5376 |
+
if "CUDA" in err_msg or "MPS" in err_msg or "illegal" in err_msg.lower():
|
| 5377 |
+
self._free_gpu_memory()
|
| 5378 |
+
self.log(f" Skipping generation tests (device error: {err_msg[:120]})")
|
| 5379 |
+
generation_failed = True
|
| 5380 |
+
else:
|
| 5381 |
+
raise
|
| 5382 |
else:
|
| 5383 |
raise
|
| 5384 |
|
|
|
|
| 5521 |
|
| 5522 |
del inputs, outputs
|
| 5523 |
self._free_gpu_memory()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5524 |
except (RuntimeError, Exception) as e:
|
| 5525 |
+
if dev.is_oom_error(e):
|
|
|
|
| 5526 |
self._free_gpu_memory()
|
| 5527 |
+
self.log(f" [batch {batch_start+1}-{batch_end}] OOM — stopping")
|
| 5528 |
+
self.log(" Skipping remaining refusal tests (out of memory)")
|
| 5529 |
oom_break = True
|
| 5530 |
+
elif isinstance(e, RuntimeError):
|
| 5531 |
+
err_msg = str(e)
|
| 5532 |
+
if "CUDA" in err_msg or "MPS" in err_msg or "illegal" in err_msg.lower():
|
| 5533 |
+
self._free_gpu_memory()
|
| 5534 |
+
self.log(f" [batch {batch_start+1}-{batch_end}] device error — stopping")
|
| 5535 |
+
self.log(f" Skipping remaining refusal tests (device error: {err_msg[:120]})")
|
| 5536 |
+
oom_break = True
|
| 5537 |
+
else:
|
| 5538 |
+
raise
|
| 5539 |
else:
|
| 5540 |
raise
|
| 5541 |
|
obliteratus/analysis/sae_abliteration.py
CHANGED
|
@@ -39,6 +39,7 @@ from dataclasses import dataclass
|
|
| 39 |
|
| 40 |
import torch
|
| 41 |
import torch.nn as nn
|
|
|
|
| 42 |
|
| 43 |
|
| 44 |
@dataclass
|
|
@@ -120,11 +121,11 @@ def _auto_detect_device(device: str | None = None) -> str:
|
|
| 120 |
"""
|
| 121 |
if device is not None and device not in ("auto",):
|
| 122 |
return device
|
| 123 |
-
if
|
| 124 |
try:
|
| 125 |
-
free_mb =
|
| 126 |
if free_mb > 512:
|
| 127 |
-
return
|
| 128 |
except Exception:
|
| 129 |
pass
|
| 130 |
return "cpu"
|
|
|
|
| 39 |
|
| 40 |
import torch
|
| 41 |
import torch.nn as nn
|
| 42 |
+
from obliteratus import device as dev
|
| 43 |
|
| 44 |
|
| 45 |
@dataclass
|
|
|
|
| 121 |
"""
|
| 122 |
if device is not None and device not in ("auto",):
|
| 123 |
return device
|
| 124 |
+
if dev.is_gpu_available():
|
| 125 |
try:
|
| 126 |
+
free_mb = dev.get_total_free_gb() * 1024
|
| 127 |
if free_mb > 512:
|
| 128 |
+
return dev.get_device()
|
| 129 |
except Exception:
|
| 130 |
pass
|
| 131 |
return "cpu"
|
obliteratus/device.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unified device abstraction for CUDA, MPS (Apple Silicon), and CPU.
|
| 2 |
+
|
| 3 |
+
All device-specific queries (availability, memory, cache management) go through
|
| 4 |
+
this module so the rest of the codebase never calls ``torch.cuda.*`` directly.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import gc
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
import platform
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
# ---------------------------------------------------------------------------
|
| 20 |
+
# Device detection
|
| 21 |
+
# ---------------------------------------------------------------------------
|
| 22 |
+
|
| 23 |
+
def is_cuda() -> bool:
|
| 24 |
+
"""True when at least one NVIDIA CUDA GPU is visible."""
|
| 25 |
+
return torch.cuda.is_available()
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def is_mps() -> bool:
|
| 29 |
+
"""True when Apple Metal Performance Shaders backend is usable."""
|
| 30 |
+
return (
|
| 31 |
+
hasattr(torch.backends, "mps")
|
| 32 |
+
and torch.backends.mps.is_available()
|
| 33 |
+
and torch.backends.mps.is_built()
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def is_gpu_available() -> bool:
|
| 38 |
+
"""True if *any* GPU backend (CUDA or MPS) is available."""
|
| 39 |
+
return is_cuda() or is_mps()
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def get_device(preference: str = "auto") -> str:
|
| 43 |
+
"""Resolve a device string.
|
| 44 |
+
|
| 45 |
+
Parameters
|
| 46 |
+
----------
|
| 47 |
+
preference : str
|
| 48 |
+
``"auto"`` picks the best GPU, ``"cuda"``/``"mps"``/``"cpu"`` forces.
|
| 49 |
+
|
| 50 |
+
Returns
|
| 51 |
+
-------
|
| 52 |
+
str
|
| 53 |
+
A PyTorch device string (``"cuda"``, ``"mps"``, or ``"cpu"``).
|
| 54 |
+
"""
|
| 55 |
+
if preference == "auto":
|
| 56 |
+
if is_cuda():
|
| 57 |
+
return "cuda"
|
| 58 |
+
if is_mps():
|
| 59 |
+
return "mps"
|
| 60 |
+
return "cpu"
|
| 61 |
+
return preference
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def get_device_name() -> str:
|
| 65 |
+
"""Human-readable name of the current accelerator."""
|
| 66 |
+
if is_cuda():
|
| 67 |
+
return torch.cuda.get_device_name(0)
|
| 68 |
+
if is_mps():
|
| 69 |
+
# Apple doesn't expose a per-chip name via MPS; use platform info.
|
| 70 |
+
chip = platform.processor() or "Apple Silicon"
|
| 71 |
+
return f"Apple {chip} (MPS)"
|
| 72 |
+
return "CPU"
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def device_count() -> int:
|
| 76 |
+
"""Number of accelerator devices (GPUs or MPS slots)."""
|
| 77 |
+
if is_cuda():
|
| 78 |
+
return torch.cuda.device_count()
|
| 79 |
+
if is_mps():
|
| 80 |
+
return 1 # MPS always exposes a single unified device
|
| 81 |
+
return 0
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# ---------------------------------------------------------------------------
|
| 85 |
+
# Memory information
|
| 86 |
+
# ---------------------------------------------------------------------------
|
| 87 |
+
|
| 88 |
+
@dataclass
|
| 89 |
+
class MemoryInfo:
|
| 90 |
+
"""Snapshot of accelerator memory (in GB)."""
|
| 91 |
+
|
| 92 |
+
used_gb: float = 0.0
|
| 93 |
+
reserved_gb: float = 0.0
|
| 94 |
+
total_gb: float = 0.0
|
| 95 |
+
free_gb: float = 0.0
|
| 96 |
+
device_name: str = "CPU"
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _system_memory_gb() -> tuple[float, float]:
|
| 100 |
+
"""Return (total_gb, available_gb) of system RAM."""
|
| 101 |
+
try:
|
| 102 |
+
import psutil
|
| 103 |
+
vm = psutil.virtual_memory()
|
| 104 |
+
return vm.total / 1024 ** 3, vm.available / 1024 ** 3
|
| 105 |
+
except ImportError:
|
| 106 |
+
pass
|
| 107 |
+
try:
|
| 108 |
+
total = os.sysconf("SC_PHYS_PAGES") * os.sysconf("SC_PAGE_SIZE") / 1024 ** 3
|
| 109 |
+
# Rough estimate: assume 60 % available if we can't query
|
| 110 |
+
return total, total * 0.6
|
| 111 |
+
except (AttributeError, ValueError):
|
| 112 |
+
return 16.0, 8.0 # conservative fallback
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def get_memory_info(device_index: int = 0) -> MemoryInfo:
    """Query memory for the given accelerator (or system RAM for MPS/CPU)."""
    name = get_device_name()

    if is_cuda():
        try:
            free_bytes, total_bytes = torch.cuda.mem_get_info(device_index)
            used_bytes = torch.cuda.memory_allocated(device_index)
            reserved_bytes = torch.cuda.memory_reserved(device_index)
        except Exception:
            # mem_get_info can fail on exotic setups; fall back to the
            # static device capacity with no usage breakdown.
            props = torch.cuda.get_device_properties(device_index)
            capacity = props.total_memory / 1024 ** 3
            return MemoryInfo(total_gb=capacity, free_gb=capacity, device_name=name)
        gib = 1024 ** 3
        return MemoryInfo(
            used_gb=used_bytes / gib,
            reserved_gb=reserved_bytes / gib,
            total_gb=total_bytes / gib,
            free_gb=free_bytes / gib,
            device_name=name,
        )

    if is_mps():
        # MPS uses unified memory — report system RAM as a proxy.  Apple's
        # unified memory is shared with the OS, so the usable fraction is
        # typically ~65-75 % of the physical total.
        total, avail = _system_memory_gb()
        usable = total * 0.70
        return MemoryInfo(
            used_gb=max(usable - avail, 0.0),
            reserved_gb=0.0,
            total_gb=usable,
            free_gb=min(avail, usable),
            device_name=name,
        )

    # CPU-only: plain system RAM.
    total, avail = _system_memory_gb()
    return MemoryInfo(total_gb=total, free_gb=avail, device_name=name)
|
| 155 |
+
|
| 156 |
+
def get_total_free_gb() -> float:
    """Sum of free memory across all accelerator devices, in GB."""
    gib = 1024 ** 3
    if is_cuda():
        free_sum = 0.0
        for idx in range(torch.cuda.device_count()):
            try:
                free_bytes, _ = torch.cuda.mem_get_info(idx)
            except Exception:
                # Live query unavailable — count the card's full capacity.
                free_bytes = torch.cuda.get_device_properties(idx).total_memory
            free_sum += free_bytes / gib
        return free_sum
    if is_mps():
        # Unified memory: usable fraction of what the OS reports available.
        _, avail = _system_memory_gb()
        return avail * 0.70
    return 0.0
|
| 173 |
+
|
| 174 |
+
# ---------------------------------------------------------------------------
|
| 175 |
+
# Cache / memory management
|
| 176 |
+
# ---------------------------------------------------------------------------
|
| 177 |
+
|
| 178 |
+
def empty_cache() -> None:
    """Release cached allocations on the current accelerator."""
    if is_cuda():
        torch.cuda.empty_cache()
        return
    # torch.mps.empty_cache() only exists since PyTorch 2.1 — probe first.
    if is_mps() and hasattr(torch.mps, "empty_cache"):
        torch.mps.empty_cache()
|
| 187 |
+
|
| 188 |
+
def free_gpu_memory() -> None:
    """Aggressive memory cleanup: GC + accelerator cache flush.

    Every step is best-effort — a failure in one step must not prevent the
    remaining cleanup from running.
    """
    gc.collect()
    if is_cuda():
        try:
            torch.cuda.empty_cache()
        except Exception:
            # Cache flush failed — at least try to drain in-flight work.
            try:
                torch.cuda.synchronize()
            except Exception:
                pass
        try:
            torch.cuda.reset_peak_memory_stats()
        except Exception:
            pass
    elif is_mps():
        # Older PyTorch builds lack these torch.mps entry points.
        for name in ("empty_cache", "synchronize"):
            fn = getattr(torch.mps, name, None)
            if fn is None:
                continue
            try:
                fn()
            except Exception:
                pass
|
| 215 |
+
|
| 216 |
+
def set_seed_all(seed: int) -> None:
    """Set the random seed on every available accelerator."""
    torch.manual_seed(seed)
    # MPS draws from the CPU generator, so only CUDA needs an extra call.
    if is_cuda():
        torch.cuda.manual_seed_all(seed)
|
| 223 |
+
|
| 224 |
+
# ---------------------------------------------------------------------------
|
| 225 |
+
# Dtype helpers
|
| 226 |
+
# ---------------------------------------------------------------------------
|
| 227 |
+
|
| 228 |
+
def default_dtype(device: str | None = None) -> torch.dtype:
|
| 229 |
+
"""Sensible default dtype for the given device."""
|
| 230 |
+
dev = device or get_device()
|
| 231 |
+
if dev == "cpu":
|
| 232 |
+
return torch.float32
|
| 233 |
+
return torch.float16
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def supports_bfloat16(device: str | None = None) -> bool:
|
| 237 |
+
"""Whether *bfloat16* is natively supported on the target device."""
|
| 238 |
+
dev = device or get_device()
|
| 239 |
+
if dev.startswith("cuda"):
|
| 240 |
+
if is_cuda():
|
| 241 |
+
major, _ = torch.cuda.get_device_capability(0)
|
| 242 |
+
return major >= 8 # Ampere+
|
| 243 |
+
return False
|
| 244 |
+
if dev == "mps":
|
| 245 |
+
# MPS added bfloat16 support in PyTorch 2.3+
|
| 246 |
+
return hasattr(torch, "__version__") and tuple(
|
| 247 |
+
int(x) for x in torch.__version__.split(".")[:2]
|
| 248 |
+
) >= (2, 3)
|
| 249 |
+
return True # CPU supports bfloat16 on most modern platforms
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def supports_float64(device: str | None = None) -> bool:
|
| 253 |
+
"""Whether *float64* is supported (MPS does NOT support it)."""
|
| 254 |
+
dev = device or get_device()
|
| 255 |
+
return dev != "mps"
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def safe_svd_dtype(tensor: torch.Tensor) -> torch.dtype:
    """Return a dtype safe for SVD on the tensor's device.

    MPS does not support float64, so precision is capped at float32 there;
    low-precision inputs (fp16/bf16) are also promoted only to float32.
    """
    on_mps = tensor.device.type == "mps"
    high_precision = tensor.dtype in (torch.float64, torch.float32)
    if on_mps or not high_precision:
        return torch.float32
    return torch.float64
|
| 267 |
+
|
| 268 |
+
# ---------------------------------------------------------------------------
|
| 269 |
+
# OOM exception matching
|
| 270 |
+
# ---------------------------------------------------------------------------
|
| 271 |
+
|
| 272 |
+
def is_oom_error(exc: BaseException) -> bool:
    """Return True if *exc* is an out-of-memory error on any backend."""
    if isinstance(exc, torch.cuda.OutOfMemoryError):
        return True
    # MPS has no dedicated OOM exception type — it raises a generic
    # RuntimeError whose message mentions "out of memory".
    return isinstance(exc, RuntimeError) and "out of memory" in str(exc).lower()
|
| 281 |
+
|
| 282 |
+
# ---------------------------------------------------------------------------
|
| 283 |
+
# Quantization compatibility
|
| 284 |
+
# ---------------------------------------------------------------------------
|
| 285 |
+
|
| 286 |
+
def supports_bitsandbytes(device: str | None = None) -> bool:
|
| 287 |
+
"""BitsAndBytes requires NVIDIA CUDA — check that."""
|
| 288 |
+
dev = device or get_device()
|
| 289 |
+
return dev.startswith("cuda")
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def supports_device_map_auto(device: str | None = None) -> bool:
|
| 293 |
+
"""Accelerate's device_map='auto' is only reliable on CUDA."""
|
| 294 |
+
dev = device or get_device()
|
| 295 |
+
return dev.startswith("cuda")
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
# ---------------------------------------------------------------------------
|
| 299 |
+
# CUDA env setup (called once at import time of abliterate.py)
|
| 300 |
+
# ---------------------------------------------------------------------------
|
| 301 |
+
|
| 302 |
+
def configure_cuda_alloc() -> None:
    """Set expandable_segments for CUDA if available."""
    # Respect an explicit user override of the allocator configuration.
    if not is_cuda() or "PYTORCH_CUDA_ALLOC_CONF" in os.environ:
        return
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
obliteratus/evaluation/benchmarks.py
CHANGED
|
@@ -26,6 +26,7 @@ import re
|
|
| 26 |
from dataclasses import dataclass, field
|
| 27 |
|
| 28 |
import torch
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
@dataclass
|
|
@@ -261,8 +262,7 @@ class BenchmarkRunner:
|
|
| 261 |
("math_reasoning", self.run_math_reasoning_probe)]:
|
| 262 |
results[name] = fn()
|
| 263 |
# Free KV caches between probes to prevent OOM on tight GPUs
|
| 264 |
-
|
| 265 |
-
torch.cuda.empty_cache()
|
| 266 |
return results
|
| 267 |
|
| 268 |
def _answer_mcq(self, question: str, choices: list[str]) -> int:
|
|
|
|
| 26 |
from dataclasses import dataclass, field
|
| 27 |
|
| 28 |
import torch
|
| 29 |
+
from obliteratus import device as dev
|
| 30 |
|
| 31 |
|
| 32 |
@dataclass
|
|
|
|
| 262 |
("math_reasoning", self.run_math_reasoning_probe)]:
|
| 263 |
results[name] = fn()
|
| 264 |
# Free KV caches between probes to prevent OOM on tight GPUs
|
| 265 |
+
dev.empty_cache()
|
|
|
|
| 266 |
return results
|
| 267 |
|
| 268 |
def _answer_mcq(self, question: str, choices: list[str]) -> int:
|
obliteratus/evaluation/heretic_eval.py
CHANGED
|
@@ -32,6 +32,7 @@ from typing import TYPE_CHECKING
|
|
| 32 |
|
| 33 |
import torch
|
| 34 |
import torch.nn.functional as F
|
|
|
|
| 35 |
|
| 36 |
if TYPE_CHECKING:
|
| 37 |
from collections.abc import Callable
|
|
@@ -363,8 +364,7 @@ def unload_harmbench_classifier() -> None:
|
|
| 363 |
model, tokenizer = _HARMBENCH_CLASSIFIER
|
| 364 |
del model, tokenizer
|
| 365 |
_HARMBENCH_CLASSIFIER = None
|
| 366 |
-
|
| 367 |
-
torch.cuda.empty_cache()
|
| 368 |
logger.info("HarmBench classifier unloaded")
|
| 369 |
|
| 370 |
|
|
@@ -432,8 +432,7 @@ def harmbench_asr(
|
|
| 432 |
|
| 433 |
# Free memory between batches
|
| 434 |
del inputs, outputs
|
| 435 |
-
|
| 436 |
-
torch.cuda.empty_cache()
|
| 437 |
|
| 438 |
n_successful = sum(per_item)
|
| 439 |
return {
|
|
@@ -536,8 +535,7 @@ def first_token_kl_on_prompts(
|
|
| 536 |
kl_values.extend(kl.cpu().tolist())
|
| 537 |
|
| 538 |
del inputs_orig, inputs_mod, logits_orig, logits_mod, first_orig, first_mod
|
| 539 |
-
|
| 540 |
-
torch.cuda.empty_cache()
|
| 541 |
|
| 542 |
mean_kl = statistics.mean(kl_values) if kl_values else 0.0
|
| 543 |
std_kl = statistics.stdev(kl_values) if len(kl_values) > 1 else 0.0
|
|
@@ -1098,8 +1096,8 @@ def run_full_heretic_eval(
|
|
| 1098 |
completions.append("")
|
| 1099 |
|
| 1100 |
del inputs
|
| 1101 |
-
if i % 25 == 0
|
| 1102 |
-
|
| 1103 |
|
| 1104 |
log(f"Generated {len(completions)} completions")
|
| 1105 |
|
|
|
|
| 32 |
|
| 33 |
import torch
|
| 34 |
import torch.nn.functional as F
|
| 35 |
+
from obliteratus import device as dev
|
| 36 |
|
| 37 |
if TYPE_CHECKING:
|
| 38 |
from collections.abc import Callable
|
|
|
|
| 364 |
model, tokenizer = _HARMBENCH_CLASSIFIER
|
| 365 |
del model, tokenizer
|
| 366 |
_HARMBENCH_CLASSIFIER = None
|
| 367 |
+
dev.empty_cache()
|
|
|
|
| 368 |
logger.info("HarmBench classifier unloaded")
|
| 369 |
|
| 370 |
|
|
|
|
| 432 |
|
| 433 |
# Free memory between batches
|
| 434 |
del inputs, outputs
|
| 435 |
+
dev.empty_cache()
|
|
|
|
| 436 |
|
| 437 |
n_successful = sum(per_item)
|
| 438 |
return {
|
|
|
|
| 535 |
kl_values.extend(kl.cpu().tolist())
|
| 536 |
|
| 537 |
del inputs_orig, inputs_mod, logits_orig, logits_mod, first_orig, first_mod
|
| 538 |
+
dev.empty_cache()
|
|
|
|
| 539 |
|
| 540 |
mean_kl = statistics.mean(kl_values) if kl_values else 0.0
|
| 541 |
std_kl = statistics.stdev(kl_values) if len(kl_values) > 1 else 0.0
|
|
|
|
| 1096 |
completions.append("")
|
| 1097 |
|
| 1098 |
del inputs
|
| 1099 |
+
if i % 25 == 0:
|
| 1100 |
+
dev.empty_cache()
|
| 1101 |
|
| 1102 |
log(f"Generated {len(completions)} completions")
|
| 1103 |
|
obliteratus/interactive.py
CHANGED
|
@@ -21,9 +21,10 @@ console = Console()
|
|
| 21 |
def _detect_compute_tier() -> str:
|
| 22 |
"""Auto-detect the best compute tier based on available hardware."""
|
| 23 |
try:
|
| 24 |
-
import
|
| 25 |
|
| 26 |
-
if
|
|
|
|
| 27 |
vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
|
| 28 |
if vram_gb >= 20:
|
| 29 |
return "large"
|
|
@@ -31,8 +32,13 @@ def _detect_compute_tier() -> str:
|
|
| 31 |
return "medium"
|
| 32 |
else:
|
| 33 |
return "small"
|
| 34 |
-
elif
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
except ImportError:
|
| 37 |
pass
|
| 38 |
return "tiny" # CPU only
|
|
@@ -237,12 +243,11 @@ def run_interactive():
|
|
| 237 |
dtype = model_preset.recommended_dtype
|
| 238 |
quantization = None
|
| 239 |
try:
|
| 240 |
-
import
|
| 241 |
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
device = "mps"
|
| 246 |
except ImportError:
|
| 247 |
pass
|
| 248 |
|
|
|
|
| 21 |
def _detect_compute_tier() -> str:
|
| 22 |
"""Auto-detect the best compute tier based on available hardware."""
|
| 23 |
try:
|
| 24 |
+
from obliteratus import device as dev
|
| 25 |
|
| 26 |
+
if dev.is_cuda():
|
| 27 |
+
import torch
|
| 28 |
vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
|
| 29 |
if vram_gb >= 20:
|
| 30 |
return "large"
|
|
|
|
| 32 |
return "medium"
|
| 33 |
else:
|
| 34 |
return "small"
|
| 35 |
+
elif dev.is_mps():
|
| 36 |
+
# Apple Silicon with unified memory — estimate from system RAM
|
| 37 |
+
mem = dev.get_memory_info()
|
| 38 |
+
if mem.total_gb >= 24:
|
| 39 |
+
return "medium" # M1 Pro/Max/Ultra, M2 Pro/Max/Ultra, M3 Pro/Max
|
| 40 |
+
else:
|
| 41 |
+
return "small" # M1/M2/M3 base (8-16 GB)
|
| 42 |
except ImportError:
|
| 43 |
pass
|
| 44 |
return "tiny" # CPU only
|
|
|
|
| 243 |
dtype = model_preset.recommended_dtype
|
| 244 |
quantization = None
|
| 245 |
try:
|
| 246 |
+
from obliteratus import device as _dev
|
| 247 |
|
| 248 |
+
resolved = _dev.get_device()
|
| 249 |
+
if resolved != "cpu":
|
| 250 |
+
device = resolved if resolved == "mps" else "auto"
|
|
|
|
| 251 |
except ImportError:
|
| 252 |
pass
|
| 253 |
|
obliteratus/mlx_backend.py
ADDED
|
@@ -0,0 +1,469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Optional MLX backend for Apple Silicon native inference and weight editing.
|
| 2 |
+
|
| 3 |
+
MLX is Apple's array framework that runs natively on the Apple Neural Engine
|
| 4 |
+
and Metal GPU. When available, it provides significantly faster inference and
|
| 5 |
+
weight manipulation than PyTorch's MPS backend on Apple hardware.
|
| 6 |
+
|
| 7 |
+
This module is entirely optional — if ``mlx`` / ``mlx-lm`` are not installed,
|
| 8 |
+
``MLX_AVAILABLE`` is ``False`` and all public functions raise ``RuntimeError``.
|
| 9 |
+
|
| 10 |
+
Install with::
|
| 11 |
+
|
| 12 |
+
pip install mlx>=0.22 mlx-lm>=0.20
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import logging
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Any, Callable
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
# ---------------------------------------------------------------------------
|
| 24 |
+
# Availability check
|
| 25 |
+
# ---------------------------------------------------------------------------
|
| 26 |
+
|
| 27 |
+
MLX_AVAILABLE = False
|
| 28 |
+
_mx = None # mlx module
|
| 29 |
+
_mlx_lm = None # mlx-lm module
|
| 30 |
+
_mlx_nn = None # mlx.nn module
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
import mlx.core as _mx_core # type: ignore[import-untyped]
|
| 34 |
+
import mlx.nn as _mlx_nn_mod # type: ignore[import-untyped]
|
| 35 |
+
import mlx_lm # type: ignore[import-untyped]
|
| 36 |
+
|
| 37 |
+
_mx = _mx_core
|
| 38 |
+
_mlx_nn = _mlx_nn_mod
|
| 39 |
+
_mlx_lm = mlx_lm
|
| 40 |
+
MLX_AVAILABLE = True
|
| 41 |
+
logger.info("MLX backend available (mlx %s)", _mx.__version__ if hasattr(_mx, "__version__") else "?")
|
| 42 |
+
except ImportError:
|
| 43 |
+
pass
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _require_mlx() -> None:
    """Raise ``RuntimeError`` if the optional MLX dependencies are missing."""
    if MLX_AVAILABLE:
        return
    raise RuntimeError(
        "MLX backend is not available. Install with: pip install mlx>=0.22 mlx-lm>=0.20"
    )
|
| 52 |
+
|
| 53 |
+
# ---------------------------------------------------------------------------
|
| 54 |
+
# Model loading
|
| 55 |
+
# ---------------------------------------------------------------------------
|
| 56 |
+
|
| 57 |
+
class MLXModelHandle:
    """Lightweight wrapper bundling an MLX-loaded model with its tokenizer."""

    def __init__(self, model: Any, tokenizer: Any, model_name: str):
        # Plain attributes — callers read these directly.
        self.model = model
        self.tokenizer = tokenizer
        self.model_name = model_name

    @property
    def config(self) -> Any:
        """The wrapped model's ``config`` object, or ``None`` if it has none."""
        return getattr(self.model, "config", None)
| 70 |
+
def load_model(
    model_name: str,
    dtype: str = "float16",
) -> MLXModelHandle:
    """Load a HuggingFace model via ``mlx-lm`` for Apple-native execution.

    Parameters
    ----------
    model_name : str
        HuggingFace model identifier (e.g. ``"meta-llama/Llama-3.2-3B-Instruct"``).
    dtype : str
        One of ``"float16"``, ``"bfloat16"``, ``"float32"``.

        NOTE(review): ``dtype`` is currently accepted and logged but never
        forwarded to ``mlx_lm.load`` — the checkpoint's stored dtype is used
        as-is. Confirm whether it should be plumbed through.

    Returns
    -------
    MLXModelHandle
        Wrapper with ``.model`` and ``.tokenizer`` attributes.

    Raises
    ------
    RuntimeError
        If the optional MLX dependencies are not installed.
    """
    _require_mlx()

    # Imported lazily so the module imports cleanly without mlx-lm.
    from mlx_lm import load  # type: ignore[import-untyped]

    logger.info("Loading %s via MLX (dtype=%s)", model_name, dtype)
    model, tokenizer = load(model_name)

    return MLXModelHandle(model=model, tokenizer=tokenizer, model_name=model_name)
|
| 97 |
+
|
| 98 |
+
# ---------------------------------------------------------------------------
|
| 99 |
+
# Inference
|
| 100 |
+
# ---------------------------------------------------------------------------
|
| 101 |
+
|
| 102 |
+
def generate(
    handle: MLXModelHandle,
    prompt: str,
    max_tokens: int = 256,
    temperature: float = 0.7,
    top_p: float = 0.9,
    repetition_penalty: float | None = None,
) -> str:
    """Generate text using the MLX model.

    Parameters
    ----------
    handle : MLXModelHandle
        A loaded MLX model handle.
    prompt : str
        The input prompt string.
    max_tokens : int
        Maximum number of tokens to generate.
    temperature : float
        Sampling temperature.
    top_p : float
        Nucleus sampling threshold.
    repetition_penalty : float or None
        Repetition penalty (1.0 = no penalty).

    Returns
    -------
    str
        Generated text completion.

    Raises
    ------
    RuntimeError
        If the optional MLX dependencies are not installed.
    """
    _require_mlx()

    from mlx_lm import generate as mlx_generate  # type: ignore[import-untyped]

    # NOTE(review): newer mlx-lm releases replaced the "temp"/"top_p"
    # keyword interface with a `sampler` argument — confirm the pinned
    # mlx-lm version (>=0.20) still accepts these kwargs.
    kwargs: dict[str, Any] = {
        "max_tokens": max_tokens,
        "temp": temperature,
        "top_p": top_p,
    }
    # Only pass the penalty when explicitly requested, so mlx-lm's own
    # default applies otherwise.
    if repetition_penalty is not None:
        kwargs["repetition_penalty"] = repetition_penalty

    return mlx_generate(
        handle.model,
        handle.tokenizer,
        prompt=prompt,
        **kwargs,
    )
|
| 151 |
+
|
| 152 |
+
# ---------------------------------------------------------------------------
|
| 153 |
+
# Activation extraction
|
| 154 |
+
# ---------------------------------------------------------------------------
|
| 155 |
+
|
| 156 |
+
def get_activations(
    handle: MLXModelHandle,
    prompts: list[str],
    layer_indices: list[int],
    max_length: int = 256,
) -> dict[int, list[Any]]:
    """Extract hidden-state activations from specified layers.

    Uses MLX's computation graph to capture intermediate outputs.

    Parameters
    ----------
    handle : MLXModelHandle
        Loaded model.
    prompts : list[str]
        Input prompts to probe.
    layer_indices : list[int]
        Which transformer layers to capture.
    max_length : int
        Maximum sequence length for tokenization.

    Returns
    -------
    dict[int, list[mlx.core.array]]
        Mapping from layer index to list of activation arrays (one per prompt).
        Each array has shape ``(hidden_dim,)`` — the last-token hidden state.

    Raises
    ------
    RuntimeError
        If MLX is unavailable, if the architecture's layer list cannot be
        located, or if no embedding layer is found.
    """
    _require_mlx()
    import mlx.core as mx  # type: ignore[import-untyped]

    model = handle.model
    tokenizer = handle.tokenizer

    # Identify the transformer block list.  This duplicates _get_layers so
    # the failure path can raise with a descriptive, probe-specific message.
    layers = None
    for attr in ("model.layers", "transformer.h", "gpt_neox.layers"):
        obj = model
        try:
            for part in attr.split("."):
                obj = getattr(obj, part)
            layers = obj
            break
        except AttributeError:
            continue

    if layers is None:
        raise RuntimeError(
            "Cannot locate transformer layers in the MLX model. "
            "Supported architectures: LLaMA, GPT-2, GPT-NeoX."
        )

    activations: dict[int, list[Any]] = {idx: [] for idx in layer_indices}
    target_set = set(layer_indices)

    for prompt in prompts:
        # Truncate rather than fail on over-long prompts.
        tokens = tokenizer.encode(prompt)
        if len(tokens) > max_length:
            tokens = tokens[:max_length]

        input_ids = mx.array([tokens])  # add a batch dimension of 1

        # Forward through embedding
        if hasattr(model, "model"):
            # LLaMA-style: model.model.embed_tokens
            embed_module = model.model
        elif hasattr(model, "transformer"):
            embed_module = model.transformer
        else:
            embed_module = model

        if hasattr(embed_module, "embed_tokens"):
            h = embed_module.embed_tokens(input_ids)
        elif hasattr(embed_module, "wte"):
            h = embed_module.wte(input_ids)
        else:
            raise RuntimeError("Cannot find embedding layer in MLX model")

        # Walk through layers, capturing activations at target indices.
        # NOTE(review): blocks are invoked as ``layer(h)`` with no attention
        # mask or KV cache — confirm the target architectures accept this
        # calling convention and that causal masking matches a normal
        # forward pass.
        for i, layer in enumerate(layers):
            h = layer(h)
            # Some layers return tuples (hidden, attention) — take first
            if isinstance(h, tuple):
                h = h[0]

            if i in target_set:
                # Last-token hidden state (captured before the final norm).
                last_hidden = h[0, -1, :]
                mx.eval(last_hidden)  # force evaluation of the lazy graph
                activations[i].append(last_hidden)

    return activations
|
| 248 |
+
|
| 249 |
+
# ---------------------------------------------------------------------------
|
| 250 |
+
# Weight manipulation
|
| 251 |
+
# ---------------------------------------------------------------------------
|
| 252 |
+
|
| 253 |
+
def get_weight(handle: MLXModelHandle, layer_idx: int, param_path: str) -> Any:
    """Retrieve a weight matrix from the model.

    Parameters
    ----------
    handle : MLXModelHandle
        Loaded model.
    layer_idx : int
        Transformer layer index.
    param_path : str
        Dot-separated path within the layer, e.g. ``"self_attn.o_proj.weight"``.

    Returns
    -------
    mlx.core.array
        The weight tensor.

    Raises
    ------
    RuntimeError
        If the optional MLX dependencies are not installed.
    """
    _require_mlx()

    # Resolve the target layer, then walk the dotted attribute path down to
    # the requested parameter.
    target = _get_layers(handle.model)[layer_idx]
    for attr in param_path.split("."):
        target = getattr(target, attr)
    return target
|
| 284 |
+
|
| 285 |
+
def modify_weights(
    handle: MLXModelHandle,
    layer_idx: int,
    param_path: str,
    modifier_fn: Callable[[Any], Any],
) -> None:
    """Modify a weight matrix in-place using a function.

    Parameters
    ----------
    handle : MLXModelHandle
        Loaded model.
    layer_idx : int
        Transformer layer index.
    param_path : str
        Dot-separated path within the layer to the weight, e.g.
        ``"self_attn.o_proj.weight"``.
    modifier_fn : callable
        Function that takes the current weight (mlx array) and returns the
        modified weight (mlx array). For abliteration, this would project
        out the refusal direction.

    Raises
    ------
    RuntimeError
        If the optional MLX dependencies are not installed.
    """
    _require_mlx()
    import mlx.core as mx  # type: ignore[import-untyped]

    model = handle.model
    layers = _get_layers(model)
    layer = layers[layer_idx]

    # Navigate to the parent module and leaf attribute
    parts = param_path.split(".")
    parent = layer
    for part in parts[:-1]:
        parent = getattr(parent, part)
    leaf_name = parts[-1]

    old_weight = getattr(parent, leaf_name)
    new_weight = modifier_fn(old_weight)

    # MLX uses a functional update pattern.
    # NOTE(review): mlx.nn.Module.update normally takes a (possibly nested)
    # parameter mapping — confirm a flat {leaf_name: array} dict is accepted
    # for the modules targeted here; otherwise the setattr branch applies.
    if hasattr(parent, "update"):
        parent.update({leaf_name: new_weight})
    else:
        setattr(parent, leaf_name, new_weight)

    mx.eval(new_weight)  # materialize the lazily-computed new weight
|
| 333 |
+
def project_out_direction(weight: Any, direction: Any) -> Any:
    """Project a direction out of a weight matrix (abliteration).

    Given weight matrix ``W`` (out × in) and unit direction ``d`` (in,),
    removes the rank-1 component along ``d`` from every row::

        W' = W - (W @ d) d^T

    Parameters
    ----------
    weight : mlx.core.array
        Weight matrix, shape ``(out_features, in_features)``.
    direction : mlx.core.array
        Unit direction vector, shape ``(in_features,)``.

    Returns
    -------
    mlx.core.array
        Modified weight with the direction projected out.

    Raises
    ------
    RuntimeError
        If the optional MLX dependencies are not installed.
    """
    _require_mlx()
    import mlx.core as mx  # type: ignore[import-untyped]

    unit = direction.astype(weight.dtype)
    # Component of each row along the direction: an (out, 1) column vector.
    row_components = mx.matmul(weight, unit[:, None])
    # Subtract the rank-1 outer product to zero out that component.
    return weight - mx.matmul(row_components, unit[None, :])
|
| 360 |
+
|
| 361 |
+
# ---------------------------------------------------------------------------
|
| 362 |
+
# Save model
|
| 363 |
+
# ---------------------------------------------------------------------------
|
| 364 |
+
|
| 365 |
+
def save_model(
    handle: MLXModelHandle,
    output_dir: str | Path,
    upload_repo: str | None = None,
) -> Path:
    """Save the (modified) MLX model to disk.

    Saves in safetensors format compatible with both MLX and HuggingFace.

    Parameters
    ----------
    handle : MLXModelHandle
        Model handle (possibly with modified weights).
    output_dir : str or Path
        Directory to save into.
    upload_repo : str or None
        If set, also uploads to HuggingFace Hub.

    Returns
    -------
    Path
        The output directory.

    Raises
    ------
    RuntimeError
        If the optional MLX dependencies are not installed.
    """
    _require_mlx()

    # Import probes mlx-lm's conversion machinery; `convert` itself is not
    # called directly below.
    from mlx_lm import convert  # type: ignore[import-untyped]

    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    # mlx-lm's save uses safetensors.  Older releases lack save_model,
    # hence the hasattr probe.
    if hasattr(_mlx_lm, "save_model"):
        _mlx_lm.save_model(str(out), handle.model, handle.tokenizer)
    else:
        # Fallback: manual save via mlx.core.save_safetensors
        import mlx.core as mx  # type: ignore[import-untyped]
        # NOTE(review): _flatten_dict is assumed to be defined elsewhere in
        # this module — confirm it exists before relying on this fallback.
        weights = dict(handle.model.parameters())
        flat = {}
        _flatten_dict(weights, "", flat)
        mx.save_safetensors(str(out / "model.safetensors"), flat)
        # Save tokenizer via transformers
        handle.tokenizer.save_pretrained(str(out))

    logger.info("MLX model saved to %s", out)

    if upload_repo:
        # upload_to_hub is not present in every mlx-lm release; degrade to a
        # warning rather than failing the save.
        try:
            from mlx_lm import upload_to_hub  # type: ignore[import-untyped]
            upload_to_hub(str(out), upload_repo)
            logger.info("Uploaded to %s", upload_repo)
        except (ImportError, AttributeError):
            logger.warning("mlx-lm upload not available — push manually with huggingface-cli")

    return out
|
| 420 |
+
|
| 421 |
+
# ---------------------------------------------------------------------------
|
| 422 |
+
# Conversion: PyTorch ↔ MLX
|
| 423 |
+
# ---------------------------------------------------------------------------
|
| 424 |
+
|
| 425 |
+
def torch_tensor_to_mlx(tensor: "torch.Tensor") -> Any: # noqa: F821
|
| 426 |
+
"""Convert a PyTorch tensor to an MLX array."""
|
| 427 |
+
_require_mlx()
|
| 428 |
+
import mlx.core as mx # type: ignore[import-untyped]
|
| 429 |
+
import numpy as np
|
| 430 |
+
|
| 431 |
+
# Move to CPU and convert via numpy
|
| 432 |
+
np_array = tensor.detach().cpu().float().numpy()
|
| 433 |
+
return mx.array(np_array)
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def mlx_to_torch_tensor(array: Any, device: str = "cpu") -> "torch.Tensor": # noqa: F821
|
| 437 |
+
"""Convert an MLX array to a PyTorch tensor."""
|
| 438 |
+
import numpy as np
|
| 439 |
+
import torch
|
| 440 |
+
|
| 441 |
+
np_array = np.array(array)
|
| 442 |
+
return torch.from_numpy(np_array).to(device)
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
# ---------------------------------------------------------------------------
|
| 446 |
+
# Internal helpers
|
| 447 |
+
# ---------------------------------------------------------------------------
|
| 448 |
+
|
| 449 |
+
def _get_layers(model: Any) -> Any:
|
| 450 |
+
"""Locate the transformer block list in an MLX model."""
|
| 451 |
+
for attr_path in ("model.layers", "transformer.h", "gpt_neox.layers"):
|
| 452 |
+
obj = model
|
| 453 |
+
try:
|
| 454 |
+
for part in attr_path.split("."):
|
| 455 |
+
obj = getattr(obj, part)
|
| 456 |
+
return obj
|
| 457 |
+
except AttributeError:
|
| 458 |
+
continue
|
| 459 |
+
raise RuntimeError("Cannot locate transformer layers in MLX model")
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
def _flatten_dict(d: dict, prefix: str, out: dict) -> None:
|
| 463 |
+
"""Flatten a nested dict with dot-separated keys."""
|
| 464 |
+
for k, v in d.items():
|
| 465 |
+
key = f"{prefix}{k}" if prefix else k
|
| 466 |
+
if isinstance(v, dict):
|
| 467 |
+
_flatten_dict(v, f"{key}.", out)
|
| 468 |
+
else:
|
| 469 |
+
out[key] = v
|
obliteratus/models/loader.py
CHANGED
|
@@ -12,6 +12,7 @@ from typing import Optional
|
|
| 12 |
import sys as _sys
|
| 13 |
|
| 14 |
import torch
|
|
|
|
| 15 |
from transformers import (
|
| 16 |
AutoConfig,
|
| 17 |
AutoModelForCausalLM,
|
|
@@ -381,24 +382,8 @@ def _estimate_model_memory_gb(config: AutoConfig, dtype: torch.dtype) -> float:
|
|
| 381 |
|
| 382 |
|
| 383 |
def _available_gpu_memory_gb() -> float:
|
| 384 |
-
"""Return free
|
| 385 |
-
|
| 386 |
-
Uses torch.cuda.mem_get_info which reports actual free memory,
|
| 387 |
-
not total capacity. Falls back to total_memory if mem_get_info
|
| 388 |
-
is unavailable (PyTorch < 1.10).
|
| 389 |
-
"""
|
| 390 |
-
if not torch.cuda.is_available():
|
| 391 |
-
return 0.0
|
| 392 |
-
total_free = 0.0
|
| 393 |
-
for i in range(torch.cuda.device_count()):
|
| 394 |
-
try:
|
| 395 |
-
free, _ = torch.cuda.mem_get_info(i)
|
| 396 |
-
total_free += free / (1024 ** 3)
|
| 397 |
-
except AttributeError:
|
| 398 |
-
# Fallback for old PyTorch without mem_get_info
|
| 399 |
-
props = torch.cuda.get_device_properties(i)
|
| 400 |
-
total_free += props.total_memory / (1024 ** 3)
|
| 401 |
-
return total_free
|
| 402 |
|
| 403 |
|
| 404 |
def _hf_token() -> str | None:
|
|
@@ -515,34 +500,54 @@ def load_model(
|
|
| 515 |
load_kwargs.pop("torch_dtype", None)
|
| 516 |
load_kwargs["device_map"] = "auto"
|
| 517 |
elif quantization in ("4bit", "8bit"):
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
from transformers import BitsAndBytesConfig
|
| 526 |
-
|
| 527 |
-
# Enable fp32 CPU offload so that models too large to fit entirely on
|
| 528 |
-
# GPU (even quantized) can spill to CPU without crashing bitsandbytes.
|
| 529 |
-
# This is critical for frontier MoE models (GLM-5 744B, DeepSeek-V3 685B,
|
| 530 |
-
# Mistral Large 3 675B, etc.) on single-GPU setups.
|
| 531 |
-
if quantization == "4bit":
|
| 532 |
-
load_kwargs["quantization_config"] = BitsAndBytesConfig(
|
| 533 |
-
load_in_4bit=True,
|
| 534 |
-
bnb_4bit_compute_dtype=torch_dtype,
|
| 535 |
-
bnb_4bit_quant_type="nf4",
|
| 536 |
-
llm_int8_enable_fp32_cpu_offload=True,
|
| 537 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 538 |
else:
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 546 |
|
| 547 |
# Offload support: provide a folder for disk offloading when GPU memory is insufficient
|
| 548 |
_offload_dir = None
|
|
@@ -560,9 +565,9 @@ def load_model(
|
|
| 560 |
# Reserve GPU headroom for inference (KV cache, activations, generate()).
|
| 561 |
# Without this, device_map="auto" packs 100% of layers onto GPU, leaving
|
| 562 |
# no room for forward passes or generation on tight-memory setups.
|
| 563 |
-
if
|
| 564 |
max_memory = {}
|
| 565 |
-
for i in range(
|
| 566 |
total = torch.cuda.get_device_properties(i).total_memory
|
| 567 |
# Reserve 15% or 2 GiB (whichever is larger) for inference headroom
|
| 568 |
reserve = max(int(total * 0.15), 2 * 1024 ** 3)
|
|
@@ -570,16 +575,8 @@ def load_model(
|
|
| 570 |
max_memory[i] = f"{usable // (1024 ** 2)}MiB"
|
| 571 |
# Allow overflow to CPU RAM, capped at 85% of physical memory
|
| 572 |
# to leave room for the OS, Python runtime, and serialization buffers.
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
cpu_ram_gb = psutil.virtual_memory().total / (1024 ** 3)
|
| 576 |
-
except ImportError:
|
| 577 |
-
try:
|
| 578 |
-
cpu_ram_gb = os.sysconf("SC_PHYS_PAGES") * os.sysconf("SC_PAGE_SIZE") / (1024 ** 3)
|
| 579 |
-
except (AttributeError, ValueError):
|
| 580 |
-
# os.sysconf is unavailable on non-POSIX platforms (Windows)
|
| 581 |
-
cpu_ram_gb = 16.0 # conservative fallback
|
| 582 |
-
cpu_budget_gb = int(cpu_ram_gb * 0.85)
|
| 583 |
max_memory["cpu"] = f"{max(cpu_budget_gb, 4)}GiB"
|
| 584 |
load_kwargs["max_memory"] = max_memory
|
| 585 |
logger.info(
|
|
@@ -625,12 +622,15 @@ def load_model(
|
|
| 625 |
|
| 626 |
if device not in ("auto",) and quantization is None and native_quant is None:
|
| 627 |
model = model.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 628 |
|
| 629 |
model.eval()
|
| 630 |
|
| 631 |
-
# Free
|
| 632 |
-
|
| 633 |
-
torch.cuda.empty_cache()
|
| 634 |
|
| 635 |
try:
|
| 636 |
tokenizer = AutoTokenizer.from_pretrained(
|
|
@@ -665,9 +665,7 @@ def load_model(
|
|
| 665 |
if gpu_gb > 0 and native_quant is not None:
|
| 666 |
# Model is pre-quantized but we can't estimate its true size.
|
| 667 |
# Check actual free memory after loading — if less than 40% free, skip snapshot.
|
| 668 |
-
free_gb =
|
| 669 |
-
for i in range(torch.cuda.device_count()):
|
| 670 |
-
free_gb += torch.cuda.mem_get_info(i)[0] / (1024 ** 3)
|
| 671 |
if free_gb < gpu_gb * 0.4:
|
| 672 |
logger.warning(
|
| 673 |
f"Auto-skipping state dict snapshot for natively quantized model "
|
|
|
|
| 12 |
import sys as _sys
|
| 13 |
|
| 14 |
import torch
|
| 15 |
+
from obliteratus import device as dev
|
| 16 |
from transformers import (
|
| 17 |
AutoConfig,
|
| 18 |
AutoModelForCausalLM,
|
|
|
|
| 382 |
|
| 383 |
|
| 384 |
def _available_gpu_memory_gb() -> float:
|
| 385 |
+
"""Return free accelerator memory in GB (CUDA, MPS, or 0 for CPU)."""
|
| 386 |
+
return dev.get_total_free_gb()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 387 |
|
| 388 |
|
| 389 |
def _hf_token() -> str | None:
|
|
|
|
| 500 |
load_kwargs.pop("torch_dtype", None)
|
| 501 |
load_kwargs["device_map"] = "auto"
|
| 502 |
elif quantization in ("4bit", "8bit"):
|
| 503 |
+
# BitsAndBytes only works on NVIDIA CUDA GPUs.
|
| 504 |
+
resolved_device = dev.get_device(device)
|
| 505 |
+
if not dev.supports_bitsandbytes(resolved_device):
|
| 506 |
+
logger.warning(
|
| 507 |
+
"BitsAndBytes quantization is not supported on %s. "
|
| 508 |
+
"Loading in %s instead.",
|
| 509 |
+
resolved_device, dtype,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 510 |
)
|
| 511 |
+
# On MPS, load normally to the device; on CPU, fall through.
|
| 512 |
+
if resolved_device == "mps":
|
| 513 |
+
device = "mps"
|
| 514 |
+
# Don't set quantization_config — fall through to normal loading.
|
| 515 |
else:
|
| 516 |
+
try:
|
| 517 |
+
import bitsandbytes # noqa: F401
|
| 518 |
+
except ImportError:
|
| 519 |
+
raise RuntimeError(
|
| 520 |
+
f"Quantization '{quantization}' requires bitsandbytes: "
|
| 521 |
+
f"pip install -U bitsandbytes>=0.46.1"
|
| 522 |
+
)
|
| 523 |
+
from transformers import BitsAndBytesConfig
|
| 524 |
+
|
| 525 |
+
# Enable fp32 CPU offload so that models too large to fit entirely on
|
| 526 |
+
# GPU (even quantized) can spill to CPU without crashing bitsandbytes.
|
| 527 |
+
# This is critical for frontier MoE models (GLM-5 744B, DeepSeek-V3 685B,
|
| 528 |
+
# Mistral Large 3 675B, etc.) on single-GPU setups.
|
| 529 |
+
if quantization == "4bit":
|
| 530 |
+
load_kwargs["quantization_config"] = BitsAndBytesConfig(
|
| 531 |
+
load_in_4bit=True,
|
| 532 |
+
bnb_4bit_compute_dtype=torch_dtype,
|
| 533 |
+
bnb_4bit_quant_type="nf4",
|
| 534 |
+
llm_int8_enable_fp32_cpu_offload=True,
|
| 535 |
+
)
|
| 536 |
+
else:
|
| 537 |
+
load_kwargs["quantization_config"] = BitsAndBytesConfig(
|
| 538 |
+
load_in_8bit=True,
|
| 539 |
+
llm_int8_enable_fp32_cpu_offload=True,
|
| 540 |
+
)
|
| 541 |
+
load_kwargs["device_map"] = "auto"
|
| 542 |
+
|
| 543 |
+
# device_map="auto" is only reliable on CUDA (accelerate doesn't support MPS).
|
| 544 |
+
if "device_map" not in load_kwargs and device == "auto":
|
| 545 |
+
resolved_device = dev.get_device(device)
|
| 546 |
+
if dev.supports_device_map_auto(resolved_device):
|
| 547 |
+
load_kwargs["device_map"] = "auto"
|
| 548 |
+
else:
|
| 549 |
+
# MPS / CPU: load to CPU first, then .to(device) after loading.
|
| 550 |
+
pass
|
| 551 |
|
| 552 |
# Offload support: provide a folder for disk offloading when GPU memory is insufficient
|
| 553 |
_offload_dir = None
|
|
|
|
| 565 |
# Reserve GPU headroom for inference (KV cache, activations, generate()).
|
| 566 |
# Without this, device_map="auto" packs 100% of layers onto GPU, leaving
|
| 567 |
# no room for forward passes or generation on tight-memory setups.
|
| 568 |
+
if dev.is_cuda():
|
| 569 |
max_memory = {}
|
| 570 |
+
for i in range(dev.device_count()):
|
| 571 |
total = torch.cuda.get_device_properties(i).total_memory
|
| 572 |
# Reserve 15% or 2 GiB (whichever is larger) for inference headroom
|
| 573 |
reserve = max(int(total * 0.15), 2 * 1024 ** 3)
|
|
|
|
| 575 |
max_memory[i] = f"{usable // (1024 ** 2)}MiB"
|
| 576 |
# Allow overflow to CPU RAM, capped at 85% of physical memory
|
| 577 |
# to leave room for the OS, Python runtime, and serialization buffers.
|
| 578 |
+
total_ram, _ = dev._system_memory_gb()
|
| 579 |
+
cpu_budget_gb = int(total_ram * 0.85)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 580 |
max_memory["cpu"] = f"{max(cpu_budget_gb, 4)}GiB"
|
| 581 |
load_kwargs["max_memory"] = max_memory
|
| 582 |
logger.info(
|
|
|
|
| 622 |
|
| 623 |
if device not in ("auto",) and quantization is None and native_quant is None:
|
| 624 |
model = model.to(device)
|
| 625 |
+
elif device == "auto" and not dev.supports_device_map_auto():
|
| 626 |
+
# MPS / CPU: device_map wasn't used, move model to best device.
|
| 627 |
+
resolved = dev.get_device()
|
| 628 |
+
model = model.to(resolved)
|
| 629 |
|
| 630 |
model.eval()
|
| 631 |
|
| 632 |
+
# Free accelerator cache after loading
|
| 633 |
+
dev.empty_cache()
|
|
|
|
| 634 |
|
| 635 |
try:
|
| 636 |
tokenizer = AutoTokenizer.from_pretrained(
|
|
|
|
| 665 |
if gpu_gb > 0 and native_quant is not None:
|
| 666 |
# Model is pre-quantized but we can't estimate its true size.
|
| 667 |
# Check actual free memory after loading — if less than 40% free, skip snapshot.
|
| 668 |
+
free_gb = dev.get_total_free_gb()
|
|
|
|
|
|
|
| 669 |
if free_gb < gpu_gb * 0.4:
|
| 670 |
logger.warning(
|
| 671 |
f"Auto-skipping state dict snapshot for natively quantized model "
|
obliteratus/reproducibility.py
CHANGED
|
@@ -38,9 +38,9 @@ def set_seed(seed: int = 42, deterministic: bool = True) -> None:
|
|
| 38 |
|
| 39 |
try:
|
| 40 |
import torch
|
|
|
|
| 41 |
torch.manual_seed(seed)
|
| 42 |
-
|
| 43 |
-
torch.cuda.manual_seed_all(seed)
|
| 44 |
|
| 45 |
if deterministic:
|
| 46 |
torch.use_deterministic_algorithms(True, warn_only=True)
|
|
|
|
| 38 |
|
| 39 |
try:
|
| 40 |
import torch
|
| 41 |
+
from obliteratus import device as dev
|
| 42 |
torch.manual_seed(seed)
|
| 43 |
+
dev.set_seed_all(seed)
|
|
|
|
| 44 |
|
| 45 |
if deterministic:
|
| 46 |
torch.use_deterministic_algorithms(True, warn_only=True)
|
obliteratus/tourney.py
CHANGED
|
@@ -372,9 +372,8 @@ class TourneyRunner:
|
|
| 372 |
# Clean up GPU between methods
|
| 373 |
gc.collect()
|
| 374 |
try:
|
| 375 |
-
import
|
| 376 |
-
|
| 377 |
-
torch.cuda.empty_cache()
|
| 378 |
except Exception:
|
| 379 |
pass
|
| 380 |
|
|
|
|
| 372 |
# Clean up GPU between methods
|
| 373 |
gc.collect()
|
| 374 |
try:
|
| 375 |
+
from obliteratus import device as dev
|
| 376 |
+
dev.empty_cache()
|
|
|
|
| 377 |
except Exception:
|
| 378 |
pass
|
| 379 |
|
requirements-apple.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Optional Apple Silicon dependencies for native MLX acceleration.
|
| 2 |
+
# Install alongside the main requirements on macOS with Apple Silicon:
|
| 3 |
+
#
|
| 4 |
+
# pip install -r requirements.txt -r requirements-apple.txt
|
| 5 |
+
#
|
| 6 |
+
# These packages are macOS-only and will fail to install on Linux/Windows.
|
| 7 |
+
mlx>=0.22
|
| 8 |
+
mlx-lm>=0.20
|