pliny-the-prompter commited on
Commit
d169612
·
verified ·
1 Parent(s): b46e97f

Upload 134 files

Browse files
app.py CHANGED
@@ -57,6 +57,7 @@ if "HF_HOME" not in os.environ:
57
 
58
  import gradio as gr
59
  import torch
 
60
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
61
 
62
  # ── ZeroGPU support ─────────────────────────────────────────────────
@@ -399,6 +400,213 @@ def _validate_hub_repo(hub_repo: str) -> str:
399
  return ""
400
 
401
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
402
  PROMPT_VOLUMES = {
403
  "33 (fast)": 33,
404
  "66 (better signal)": 66,
@@ -447,25 +655,11 @@ def _should_quantize(model_id: str, is_preset: bool = False) -> str | None:
447
  # ---------------------------------------------------------------------------
448
 
449
  def _clear_gpu():
450
- """Free GPU memory. Resilient to CUDA errors (e.g. after illegal memory access)."""
451
  with _lock:
452
  _state["model"] = None
453
  _state["tokenizer"] = None
454
- gc.collect()
455
- if torch.cuda.is_available():
456
- try:
457
- torch.cuda.empty_cache()
458
- except Exception:
459
- # CUDA context may be poisoned after an illegal-address error;
460
- # attempt a device reset so subsequent loads can succeed.
461
- try:
462
- torch.cuda.synchronize()
463
- except Exception:
464
- pass
465
- try:
466
- torch.cuda.reset_peak_memory_stats()
467
- except Exception:
468
- pass
469
 
470
 
471
  def _install_steering_hooks(model, steering_meta: dict) -> int:
@@ -589,16 +783,16 @@ def _cleanup_disk():
589
  # ---------------------------------------------------------------------------
590
 
591
  def _get_vram_html() -> str:
592
- """Return an HTML snippet showing GPU VRAM usage as a styled bar."""
593
- if not torch.cuda.is_available():
594
  return (
595
  '<div style="text-align:center;color:#4a5568;font-size:0.72rem;'
596
  'letter-spacing:1px;margin-top:6px;">CPU ONLY — NO GPU DETECTED</div>'
597
  )
598
  try:
599
- used = torch.cuda.memory_allocated() / 1024**3
600
- reserved = torch.cuda.memory_reserved() / 1024**3
601
- total = torch.cuda.get_device_properties(0).total_memory / 1024**3
602
  pct = (used / total * 100) if total > 0 else 0
603
  # Color shifts from green → yellow → red
604
  if pct < 50:
@@ -607,12 +801,17 @@ def _get_vram_html() -> str:
607
  bar_color = "#ffcc00"
608
  else:
609
  bar_color = "#ff003c"
610
- device_name = torch.cuda.get_device_name(0)
 
 
 
 
 
611
  return (
612
  f'<div style="margin:6px auto 0;max-width:480px;">'
613
  f'<div style="display:flex;justify-content:space-between;font-size:0.68rem;'
614
  f'color:#4a5568;letter-spacing:1px;margin-bottom:2px;">'
615
- f'<span>GPU: {device_name}</span>'
616
  f'<span>{used:.1f} / {total:.1f} GB ({pct:.0f}%)</span></div>'
617
  f'<div style="background:#0a0a0f;border:1px solid #1a1f2e;border-radius:3px;'
618
  f'height:10px;overflow:hidden;">'
@@ -620,11 +819,11 @@ def _get_vram_html() -> str:
620
  f'box-shadow:0 0 6px {bar_color};transition:width 0.5s ease;"></div></div>'
621
  f'<div style="display:flex;justify-content:space-between;font-size:0.6rem;'
622
  f'color:#333;margin-top:1px;">'
623
- f'<span style="color:#4a5568;">reserved: {reserved:.1f} GB</span></div>'
624
  f'</div>'
625
  )
626
  except Exception:
627
- return '<div style="text-align:center;color:#4a5568;font-size:0.72rem;">VRAM: unavailable</div>'
628
 
629
 
630
  # ---------------------------------------------------------------------------
@@ -1067,8 +1266,7 @@ def benchmark(
1067
  pass
1068
  pipeline_ref[0] = None
1069
  gc.collect()
1070
- if torch.cuda.is_available():
1071
- torch.cuda.empty_cache()
1072
 
1073
  yield (
1074
  f"**{method_key} complete** ({mi + 1}/{len(methods_to_test)}) \u2014 {_bench_elapsed()}",
@@ -1418,8 +1616,7 @@ def benchmark_multi_model(
1418
  pass
1419
  pipeline_ref[0] = None
1420
  gc.collect()
1421
- if torch.cuda.is_available():
1422
- torch.cuda.empty_cache()
1423
 
1424
  yield (
1425
  f"**{model_id} complete** ({mi + 1}/{len(model_choices)}) \u2014 {_mm_elapsed()}",
@@ -1518,7 +1715,6 @@ def _format_multi_model_results(results: list[dict], context: dict | None = None
1518
 
1519
  @spaces.GPU(duration=300)
1520
  def obliterate(model_choice: str, method_choice: str,
1521
- hub_auto_push: bool, hub_repo: str,
1522
  prompt_volume_choice: str, dataset_source_choice: str,
1523
  custom_harmful: str, custom_harmless: str,
1524
  # Advanced params (sliders)
@@ -1551,14 +1747,6 @@ def obliterate(model_choice: str, method_choice: str,
1551
  model_id = MODELS.get(model_choice, model_choice)
1552
  is_preset = model_choice in MODELS
1553
  method = METHODS.get(method_choice, "advanced")
1554
- # Resolve push-to-hub: explicit repo overrides auto-naming
1555
- _hub_override = hub_repo.strip() if hub_repo and hub_repo.strip() else None
1556
- if _hub_override:
1557
- push_to_hub = _hub_override
1558
- elif hub_auto_push:
1559
- push_to_hub = "auto" # resolved to {user}/{model}-OBLITERATED at push time
1560
- else:
1561
- push_to_hub = None
1562
  prompt_volume = PROMPT_VOLUMES.get(prompt_volume_choice, 33)
1563
 
1564
  # Resolve "adaptive" → telemetry-recommended method for this model
@@ -1606,26 +1794,6 @@ def obliterate(model_choice: str, method_choice: str,
1606
  )
1607
  return
1608
 
1609
- # Early validation: Hub repo format + token availability
1610
- # Resolve which token to use: user's own HF_TOKEN, or the shared community token.
1611
- _user_token = os.environ.get("HF_TOKEN")
1612
- _hub_token = _user_token or _HUB_COMMUNITY_TOKEN
1613
- _hub_org = None if _user_token else _HUB_COMMUNITY_ORG # community org only when using shared token
1614
- if push_to_hub:
1615
- if push_to_hub != "auto" and not re.match(r'^[a-zA-Z0-9_-]+/[a-zA-Z0-9_.-]+$', push_to_hub):
1616
- yield (
1617
- "**Error:** Invalid Hub repo format. Use `username/model-name`.",
1618
- "", gr.update(), gr.update(), gr.update(), gr.update(),
1619
- )
1620
- return
1621
- if not _hub_token:
1622
- yield (
1623
- "**Error:** No Hub token available. Set HF_TOKEN or OBLITERATUS_HUB_TOKEN "
1624
- "as an environment variable or Space secret.",
1625
- "", gr.update(), gr.update(), gr.update(), gr.update(),
1626
- )
1627
- return
1628
-
1629
  # Resolve dataset source — custom prompts override the dropdown
1630
  use_custom = custom_harmful and custom_harmful.strip()
1631
  dataset_key = get_source_key_from_label(dataset_source_choice) if dataset_source_choice else "builtin"
@@ -1699,9 +1867,6 @@ def obliterate(model_choice: str, method_choice: str,
1699
  output_dir=save_dir,
1700
  device="auto",
1701
  dtype="float16",
1702
- push_to_hub=push_to_hub,
1703
- hub_token=_hub_token,
1704
- hub_community_org=_hub_org,
1705
  quantization=quantization,
1706
  trust_remote_code=is_preset,
1707
  harmful_prompts=harmful_all[:n],
@@ -1719,9 +1884,6 @@ def obliterate(model_choice: str, method_choice: str,
1719
  device="auto",
1720
  dtype="float16",
1721
  method=method,
1722
- push_to_hub=push_to_hub,
1723
- hub_token=_hub_token,
1724
- hub_community_org=_hub_org,
1725
  quantization=quantization,
1726
  trust_remote_code=is_preset,
1727
  harmful_prompts=harmful_all[:n],
@@ -1774,12 +1936,6 @@ def obliterate(model_choice: str, method_choice: str,
1774
  log_lines.append(f"Dataset: {source_label}")
1775
  vol_label = "all" if prompt_volume == -1 else str(prompt_volume)
1776
  log_lines.append(f"Prompt volume: {vol_label} pairs")
1777
- if push_to_hub:
1778
- if push_to_hub == "auto":
1779
- _ns = _hub_org or "{you}"
1780
- log_lines.append(f"Push to Hub: auto ({_ns}/{{model}}-OBLITERATED)")
1781
- else:
1782
- log_lines.append(f"Push to Hub: {push_to_hub}")
1783
  if quantization:
1784
  log_lines.append(f"Quantization: {quantization} (auto-detected for GPU fit)")
1785
  log_lines.append("")
@@ -2118,11 +2274,11 @@ def chat_respond(message: str, history: list[dict], system_prompt: str,
2118
  _needs_reload = model is None or tokenizer is None
2119
  if not _needs_reload:
2120
  try:
2121
- dev = next(model.parameters()).device
2122
- if dev.type == "meta":
2123
  _needs_reload = True
2124
- elif torch.cuda.is_available() and dev.type != "cuda":
2125
- model.to("cuda")
2126
  except Exception:
2127
  _needs_reload = True
2128
 
@@ -2552,11 +2708,11 @@ def ab_chat_respond(message: str, history_left: list[dict], history_right: list[
2552
  _needs_reload = abliterated_model is None or tokenizer is None
2553
  if not _needs_reload:
2554
  try:
2555
- dev = next(abliterated_model.parameters()).device
2556
- if dev.type == "meta":
2557
  _needs_reload = True
2558
- elif torch.cuda.is_available() and dev.type != "cuda":
2559
- abliterated_model.to("cuda")
2560
  except Exception:
2561
  _needs_reload = True
2562
 
@@ -2689,8 +2845,7 @@ def ab_chat_respond(message: str, history_left: list[dict], history_right: list[
2689
  abl_device = next(abliterated_model.parameters()).device
2690
  abliterated_model.to("cpu")
2691
  gc.collect()
2692
- if torch.cuda.is_available():
2693
- torch.cuda.empty_cache()
2694
 
2695
  model_id = MODELS.get(model_name, model_name)
2696
  # Only trust remote code for known preset models, not arbitrary user-supplied IDs
@@ -2742,8 +2897,7 @@ def ab_chat_respond(message: str, history_left: list[dict], history_right: list[
2742
  # Free the original model
2743
  del original_model
2744
  gc.collect()
2745
- if torch.cuda.is_available():
2746
- torch.cuda.empty_cache()
2747
 
2748
  except Exception as e:
2749
  original_response = f"*Could not load original model for comparison: {e}*"
@@ -2752,7 +2906,7 @@ def ab_chat_respond(message: str, history_left: list[dict], history_right: list[
2752
  # Use torch.device("cuda") rather than the captured abl_device, since
2753
  # on ZeroGPU the original device reference may point to a stale context.
2754
  try:
2755
- restore_device = torch.device("cuda") if torch.cuda.is_available() else abl_device
2756
  abliterated_model.to(restore_device)
2757
  except Exception:
2758
  pass # If GPU restore fails, model stays on CPU (still usable)
@@ -2870,8 +3024,7 @@ def strength_sweep(model_choice: str, method_choice: str,
2870
 
2871
  # Cleanup between runs
2872
  gc.collect()
2873
- if torch.cuda.is_available():
2874
- torch.cuda.empty_cache()
2875
 
2876
  # Generate dose-response curve
2877
  gallery = None
@@ -2963,6 +3116,117 @@ def _format_sweep_results(results: list[dict]) -> str:
2963
  return "\n".join(lines)
2964
 
2965
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2966
  # ---------------------------------------------------------------------------
2967
  # Export Research Artifacts
2968
  # ---------------------------------------------------------------------------
@@ -3523,20 +3787,10 @@ with gr.Blocks(theme=THEME, css=CSS, js=_JS, title="OBLITERATUS", fill_height=Tr
3523
  lines=5,
3524
  )
3525
 
3526
- with gr.Row():
3527
- hub_auto_push = gr.Checkbox(
3528
- label="Auto-push to Hub",
3529
- value=False,
3530
- info=f"Pushes your model to {_HUB_COMMUNITY_ORG}/{{model}}-OBLITERATED on HF Hub. "
3531
- "No token needed — works out of the box!",
3532
- )
3533
- hub_repo = gr.Textbox(
3534
- label="Push to Hub (optional override)",
3535
- placeholder="auto-filled when checkbox is ticked, or type your own",
3536
- info="Leave blank with checkbox ticked for auto-naming, "
3537
- "or enter a custom repo ID (e.g. your-username/my-model).",
3538
- )
3539
- hub_warning_md = gr.Markdown("")
3540
 
3541
  # ── Advanced Settings (auto-populated from method preset) ────
3542
  _defaults = _get_preset_defaults("advanced (recommended)")
@@ -4168,33 +4422,19 @@ tradeoff point where refusal is minimized with minimal capability damage.
4168
  with gr.Tab("Tourney", id="tourney"):
4169
  gr.Markdown("""### March Madness Tournament
4170
  Pit **all abliteration methods** against each other in elimination rounds.
4171
- The winner gets auto-pushed to HuggingFace Hub.
4172
 
4173
  **Round 1 — Qualifiers:** All methods, reduced prompts. Bottom half eliminated.
4174
  **Round 2 — Semifinals:** Survivors, full prompts. Bottom half eliminated.
4175
  **Round 3 — Finals:** Top contenders, maximum prompts. Champion crowned.
4176
  """)
4177
- with gr.Row():
4178
- with gr.Column(scale=2):
4179
- tourney_model_dd = gr.Dropdown(
4180
- choices=list(MODELS.keys()),
4181
- value="Alibaba (Qwen) / Qwen3-4B",
4182
- label="Target Model",
4183
- info="Select a model to tournament-abliterate",
4184
- allow_custom_value=True,
4185
- )
4186
- with gr.Column(scale=1):
4187
- tourney_hub_org = gr.Textbox(
4188
- label="HF Hub Org (optional)",
4189
- placeholder="my-org",
4190
- info="Push winner to hub-org/model-name-OBLITERATED",
4191
- )
4192
- with gr.Column(scale=1):
4193
- tourney_hub_repo = gr.Textbox(
4194
- label="HF Hub Repo (optional)",
4195
- placeholder="org/repo-name",
4196
- info="Overrides Hub Org — full repo ID",
4197
- )
4198
 
4199
  with gr.Accordion("Advanced Settings", open=False):
4200
  with gr.Row():
@@ -4223,97 +4463,12 @@ The winner gets auto-pushed to HuggingFace Hub.
4223
  interactive=False,
4224
  )
4225
 
4226
- @tourney_btn.click(
4227
- inputs=[tourney_model_dd, tourney_hub_org, tourney_hub_repo,
 
4228
  tourney_dataset_dd, tourney_quant_dd],
4229
  outputs=[tourney_status, tourney_bracket, tourney_log],
4230
  )
4231
- def run_tourney(model_choice, hub_org, hub_repo, dataset, quantization):
4232
- if not model_choice or not model_choice.strip():
4233
- yield "**Error:** Select a model first.", "", ""
4234
- return
4235
-
4236
- from obliteratus.tourney import TourneyRunner, render_bracket
4237
-
4238
- # Resolve display label → HuggingFace model ID
4239
- model_id = model_choice.strip()
4240
- if model_id in MODELS:
4241
- model_id = MODELS[model_id]
4242
-
4243
- hub_org_val = hub_org.strip() if hub_org and hub_org.strip() else None
4244
- hub_repo_val = hub_repo.strip() if hub_repo and hub_repo.strip() else None
4245
- quant = quantization if quantization != "none" else None
4246
-
4247
- log_lines = []
4248
-
4249
- def on_log(msg):
4250
- log_lines.append(msg)
4251
-
4252
- def on_round(rnd):
4253
- pass # logged via on_log
4254
-
4255
- dataset_key = get_source_key_from_label(dataset) if dataset else "builtin"
4256
-
4257
- runner = TourneyRunner(
4258
- model_name=model_id,
4259
- hub_org=hub_org_val,
4260
- hub_repo=hub_repo_val,
4261
- dataset_key=dataset_key,
4262
- quantization=quant,
4263
- on_log=on_log,
4264
- on_round=on_round,
4265
- )
4266
-
4267
- # Yield progress updates during tournament
4268
- import threading
4269
- result_ref = [None]
4270
- error_ref = [None]
4271
-
4272
- def _run():
4273
- try:
4274
- result_ref[0] = runner.run()
4275
- except Exception as e:
4276
- error_ref[0] = e
4277
-
4278
- thread = threading.Thread(target=_run, daemon=True)
4279
- thread.start()
4280
-
4281
- while thread.is_alive():
4282
- yield (
4283
- "**Tournament in progress...**",
4284
- "",
4285
- "\n".join(log_lines[-100:]),
4286
- )
4287
- time.sleep(1.0)
4288
-
4289
- thread.join(timeout=10)
4290
-
4291
- if error_ref[0]:
4292
- yield (
4293
- f"**Error:** {error_ref[0]}",
4294
- "",
4295
- "\n".join(log_lines),
4296
- )
4297
- return
4298
-
4299
- result = result_ref[0]
4300
- if result and result.winner:
4301
- bracket_md = render_bracket(result)
4302
- hub_msg = ""
4303
- if result.hub_repo:
4304
- hub_msg = f"\nPushed to [{result.hub_repo}](https://huggingface.co/{result.hub_repo})"
4305
- yield (
4306
- f"**Champion: `{result.winner.method}`** "
4307
- f"(score: {result.winner.score:.4f}){hub_msg}",
4308
- bracket_md,
4309
- "\n".join(log_lines),
4310
- )
4311
- else:
4312
- yield (
4313
- "**Tournament complete** — no winner determined.",
4314
- "",
4315
- "\n".join(log_lines),
4316
- )
4317
 
4318
  # ── Tab 7: Export ────────────────────────────────────────────────���
4319
  with gr.Tab("Export", id="export"):
@@ -4336,7 +4491,94 @@ Download all intermediate data from your last obliteration run as a ZIP archive.
4336
  outputs=[export_file, export_status],
4337
  )
4338
 
4339
- # ── Tab 7: Leaderboard ────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4340
  with gr.Tab("Leaderboard", id="leaderboard"):
4341
  gr.Markdown("""### Community Leaderboard
4342
  All benchmark results from **every OBLITERATUS Space** (including duplicated copies) are
@@ -4562,12 +4804,6 @@ Built on the shoulders of:
4562
  outputs=[prompt_vol_dd, dataset_info_md],
4563
  )
4564
 
4565
- # Wire hub repo → live validation
4566
- hub_repo.change(
4567
- fn=_validate_hub_repo,
4568
- inputs=[hub_repo],
4569
- outputs=[hub_warning_md],
4570
- )
4571
 
4572
  # Wire benchmark → Chat/A/B cross-tab dropdown updates
4573
  bench_btn.click(
@@ -4616,7 +4852,7 @@ Built on the shoulders of:
4616
  # may not fire after generator teardown.
4617
  obliterate_btn.click(
4618
  fn=obliterate,
4619
- inputs=[model_dd, method_dd, hub_auto_push, hub_repo, prompt_vol_dd, dataset_dd,
4620
  custom_harmful_tb, custom_harmless_tb] + _adv_controls,
4621
  outputs=[status_md, log_box, chat_status, session_model_dd, metrics_md, ab_session_model_dd],
4622
  ).then(
 
57
 
58
  import gradio as gr
59
  import torch
60
+ from obliteratus import device as dev
61
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
62
 
63
  # ── ZeroGPU support ─────────────────────────────────────────────────
 
400
  return ""
401
 
402
 
403
+ # ---------------------------------------------------------------------------
404
+ # Push to Hub — dedicated tab backend
405
+ # ---------------------------------------------------------------------------
406
+
407
+ def _generate_model_card(meta: dict) -> str:
408
+ """Generate a HuggingFace model card README for a session model."""
409
+ model_id = meta.get("model_id", "unknown")
410
+ method = meta.get("method", "unknown")
411
+ source = meta.get("source", "obliterate")
412
+ short_model = model_id.split("/")[-1] if "/" in model_id else model_id
413
+
414
+ metrics_table = ""
415
+ tourney_metrics = meta.get("tourney_metrics")
416
+ if tourney_metrics:
417
+ rows = "\n".join(
418
+ f"| {k.replace('_', ' ').title()} | {v:.4f} |"
419
+ for k, v in tourney_metrics.items() if isinstance(v, (int, float))
420
+ )
421
+ metrics_table = f"\n## Metrics\n\n| Metric | Value |\n|--------|-------|\n{rows}\n"
422
+
423
+ return f"""---
424
+ language: en
425
+ tags:
426
+ - obliteratus
427
+ - abliteration
428
+ - uncensored
429
+ - {source}
430
+ base_model: {model_id}
431
+ ---
432
+
433
+ # {short_model}-OBLITERATED
434
+
435
+ This model was abliterated using the **`{method}`** method via
436
+ [OBLITERATUS](https://github.com/elder-plinius/OBLITERATUS).
437
+
438
+ | Detail | Value |
439
+ |--------|-------|
440
+ | Base model | `{model_id}` |
441
+ | Method | `{method}` |
442
+ | Source | {source} |
443
+ {metrics_table}
444
+ ## How to Use
445
+
446
+ ```python
447
+ from transformers import AutoModelForCausalLM, AutoTokenizer
448
+
449
+ model = AutoModelForCausalLM.from_pretrained("{short_model}-OBLITERATED")
450
+ tokenizer = AutoTokenizer.from_pretrained("{short_model}-OBLITERATED")
451
+
452
+ prompt = "Hello, how are you?"
453
+ inputs = tokenizer(prompt, return_tensors="pt")
454
+ outputs = model.generate(**inputs, max_new_tokens=256)
455
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
456
+ ```
457
+
458
+ ## About OBLITERATUS
459
+
460
+ OBLITERATUS is an open-source tool for removing refusal behavior from language
461
+ models via activation engineering (abliteration). Learn more at
462
+ [github.com/elder-plinius/OBLITERATUS](https://github.com/elder-plinius/OBLITERATUS).
463
+ """
464
+
465
+
466
+ def _get_hub_session_info(label: str) -> str:
467
+ """Return a markdown summary of the selected session model."""
468
+ if not label or label.startswith("("):
469
+ return ""
470
+ meta = _session_models.get(label)
471
+ if not meta:
472
+ return "*Session model not found — try refreshing the list.*"
473
+ lines = [
474
+ f"**Model:** `{meta.get('model_id', 'unknown')}`",
475
+ f"**Method:** `{meta.get('method', 'unknown')}`",
476
+ f"**Source:** {meta.get('source', 'unknown')}",
477
+ f"**Path:** `{meta.get('output_dir', 'N/A')}`",
478
+ ]
479
+ score = meta.get("tourney_score")
480
+ if score is not None:
481
+ lines.append(f"**Tourney score:** {score:.4f}")
482
+ return "\n".join(lines)
483
+
484
+
485
+ def _auto_hub_repo_id(label: str) -> str:
486
+ """Generate an auto-filled Hub repo ID for the selected session model."""
487
+ meta = _session_models.get(label)
488
+ if not meta:
489
+ return ""
490
+ model_id = meta.get("model_id", "")
491
+ import re
492
+ short = model_id.split("/")[-1] if "/" in model_id else model_id
493
+ short = re.sub(r"[^a-zA-Z0-9\-.]", "-", short)
494
+ return f"{_HUB_COMMUNITY_ORG}/{short}-OBLITERATED"
495
+
496
+
497
+ def push_session_to_hub(
498
+ session_label: str,
499
+ hub_repo_id: str,
500
+ hub_token_input: str,
501
+ refine_enabled: bool,
502
+ refine_regularization: float,
503
+ refine_passes: int,
504
+ progress=gr.Progress(),
505
+ ):
506
+ """Push a session model to HuggingFace Hub, with optional refinement."""
507
+ import os
508
+ import re
509
+
510
+ if not session_label or session_label.startswith("("):
511
+ yield "**Error:** Select a session model first.", ""
512
+ return
513
+
514
+ meta = _session_models.get(session_label)
515
+ if not meta:
516
+ yield "**Error:** Session model not found. Try refreshing the list.", ""
517
+ return
518
+
519
+ output_dir = meta.get("output_dir", "")
520
+ if not output_dir or not Path(output_dir).exists():
521
+ yield f"**Error:** Model directory not found: `{output_dir}`", ""
522
+ return
523
+
524
+ # Resolve repo ID
525
+ repo_id = hub_repo_id.strip() if hub_repo_id else ""
526
+ if not repo_id:
527
+ repo_id = _auto_hub_repo_id(session_label)
528
+ if not repo_id:
529
+ yield "**Error:** Could not determine Hub repo ID.", ""
530
+ return
531
+ if not re.match(r'^[a-zA-Z0-9_-]+/[a-zA-Z0-9_.-]+$', repo_id):
532
+ yield "**Error:** Invalid repo format. Use `username/model-name`.", ""
533
+ return
534
+
535
+ # Resolve token
536
+ token = hub_token_input.strip() if hub_token_input else None
537
+ if not token:
538
+ token = os.environ.get("HF_TOKEN") or _HUB_COMMUNITY_TOKEN
539
+ if not token:
540
+ yield (
541
+ "**Error:** No Hub token available. Enter a token above, "
542
+ "or set `HF_TOKEN` / `OBLITERATUS_HUB_TOKEN` as an environment variable.",
543
+ "",
544
+ )
545
+ return
546
+
547
+ # Optional refinement pass
548
+ if refine_enabled and refine_passes > 0:
549
+ progress(0.1, desc="Refining model...")
550
+ yield "Applying refinement passes...", ""
551
+ try:
552
+ from obliteratus.abliterate import AbliterationPipeline
553
+ from obliteratus.prompts import load_dataset_source
554
+
555
+ dataset_key = meta.get("dataset_key", "builtin")
556
+ if dataset_key == "custom":
557
+ dataset_key = "builtin"
558
+ harmful, harmless = load_dataset_source(dataset_key)
559
+ n = min(33, len(harmful), len(harmless))
560
+
561
+ pipeline = AbliterationPipeline(
562
+ model_name=output_dir, # load from saved checkpoint
563
+ output_dir=output_dir,
564
+ device="auto",
565
+ dtype="float16",
566
+ method=meta.get("method", "advanced"),
567
+ regularization=refine_regularization,
568
+ refinement_passes=refine_passes,
569
+ harmful_prompts=harmful[:n],
570
+ harmless_prompts=harmless[:n],
571
+ )
572
+ pipeline.run()
573
+ except Exception as e:
574
+ yield f"**Refinement failed:** {e}", ""
575
+ return
576
+
577
+ # Generate model card
578
+ progress(0.5, desc="Generating model card...")
579
+ yield f"Generating model card and uploading to `{repo_id}`...", ""
580
+ card_content = _generate_model_card(meta)
581
+ card_path = Path(output_dir) / "README.md"
582
+ card_path.write_text(card_content)
583
+
584
+ # Upload to Hub
585
+ progress(0.6, desc="Uploading to Hub...")
586
+ try:
587
+ from huggingface_hub import HfApi
588
+ api = HfApi(token=token)
589
+ api.create_repo(repo_id, exist_ok=True)
590
+
591
+ method = meta.get("method", "unknown")
592
+ model_id = meta.get("model_id", "unknown")
593
+ api.upload_folder(
594
+ folder_path=output_dir,
595
+ repo_id=repo_id,
596
+ commit_message=f"OBLITERATUS: {method} on {model_id}",
597
+ )
598
+ except Exception as e:
599
+ yield f"**Upload failed:** {e}", ""
600
+ return
601
+
602
+ progress(1.0, desc="Done!")
603
+ hub_url = f"https://huggingface.co/{repo_id}"
604
+ yield (
605
+ f"**Pushed successfully to [{repo_id}]({hub_url})**",
606
+ f"[Open on HuggingFace Hub]({hub_url})",
607
+ )
608
+
609
+
610
  PROMPT_VOLUMES = {
611
  "33 (fast)": 33,
612
  "66 (better signal)": 66,
 
655
  # ---------------------------------------------------------------------------
656
 
657
  def _clear_gpu():
658
+ """Free GPU/accelerator memory. Resilient to device errors."""
659
  with _lock:
660
  _state["model"] = None
661
  _state["tokenizer"] = None
662
+ dev.free_gpu_memory()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
663
 
664
 
665
  def _install_steering_hooks(model, steering_meta: dict) -> int:
 
783
  # ---------------------------------------------------------------------------
784
 
785
  def _get_vram_html() -> str:
786
+ """Return an HTML snippet showing GPU/accelerator memory usage as a styled bar."""
787
+ if not dev.is_gpu_available():
788
  return (
789
  '<div style="text-align:center;color:#4a5568;font-size:0.72rem;'
790
  'letter-spacing:1px;margin-top:6px;">CPU ONLY — NO GPU DETECTED</div>'
791
  )
792
  try:
793
+ mem = dev.get_memory_info()
794
+ used = mem.used_gb
795
+ total = mem.total_gb
796
  pct = (used / total * 100) if total > 0 else 0
797
  # Color shifts from green → yellow → red
798
  if pct < 50:
 
801
  bar_color = "#ffcc00"
802
  else:
803
  bar_color = "#ff003c"
804
+ device_name = mem.device_name
805
+ reserved_html = (
806
+ f'<span style="color:#4a5568;">reserved: {mem.reserved_gb:.1f} GB</span>'
807
+ if mem.reserved_gb > 0
808
+ else f'<span style="color:#4a5568;">unified memory</span>'
809
+ )
810
  return (
811
  f'<div style="margin:6px auto 0;max-width:480px;">'
812
  f'<div style="display:flex;justify-content:space-between;font-size:0.68rem;'
813
  f'color:#4a5568;letter-spacing:1px;margin-bottom:2px;">'
814
+ f'<span>{device_name}</span>'
815
  f'<span>{used:.1f} / {total:.1f} GB ({pct:.0f}%)</span></div>'
816
  f'<div style="background:#0a0a0f;border:1px solid #1a1f2e;border-radius:3px;'
817
  f'height:10px;overflow:hidden;">'
 
819
  f'box-shadow:0 0 6px {bar_color};transition:width 0.5s ease;"></div></div>'
820
  f'<div style="display:flex;justify-content:space-between;font-size:0.6rem;'
821
  f'color:#333;margin-top:1px;">'
822
+ f'{reserved_html}</div>'
823
  f'</div>'
824
  )
825
  except Exception:
826
+ return '<div style="text-align:center;color:#4a5568;font-size:0.72rem;">Memory: unavailable</div>'
827
 
828
 
829
  # ---------------------------------------------------------------------------
 
1266
  pass
1267
  pipeline_ref[0] = None
1268
  gc.collect()
1269
+ dev.empty_cache()
 
1270
 
1271
  yield (
1272
  f"**{method_key} complete** ({mi + 1}/{len(methods_to_test)}) \u2014 {_bench_elapsed()}",
 
1616
  pass
1617
  pipeline_ref[0] = None
1618
  gc.collect()
1619
+ dev.empty_cache()
 
1620
 
1621
  yield (
1622
  f"**{model_id} complete** ({mi + 1}/{len(model_choices)}) \u2014 {_mm_elapsed()}",
 
1715
 
1716
  @spaces.GPU(duration=300)
1717
  def obliterate(model_choice: str, method_choice: str,
 
1718
  prompt_volume_choice: str, dataset_source_choice: str,
1719
  custom_harmful: str, custom_harmless: str,
1720
  # Advanced params (sliders)
 
1747
  model_id = MODELS.get(model_choice, model_choice)
1748
  is_preset = model_choice in MODELS
1749
  method = METHODS.get(method_choice, "advanced")
 
 
 
 
 
 
 
 
1750
  prompt_volume = PROMPT_VOLUMES.get(prompt_volume_choice, 33)
1751
 
1752
  # Resolve "adaptive" → telemetry-recommended method for this model
 
1794
  )
1795
  return
1796
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1797
  # Resolve dataset source — custom prompts override the dropdown
1798
  use_custom = custom_harmful and custom_harmful.strip()
1799
  dataset_key = get_source_key_from_label(dataset_source_choice) if dataset_source_choice else "builtin"
 
1867
  output_dir=save_dir,
1868
  device="auto",
1869
  dtype="float16",
 
 
 
1870
  quantization=quantization,
1871
  trust_remote_code=is_preset,
1872
  harmful_prompts=harmful_all[:n],
 
1884
  device="auto",
1885
  dtype="float16",
1886
  method=method,
 
 
 
1887
  quantization=quantization,
1888
  trust_remote_code=is_preset,
1889
  harmful_prompts=harmful_all[:n],
 
1936
  log_lines.append(f"Dataset: {source_label}")
1937
  vol_label = "all" if prompt_volume == -1 else str(prompt_volume)
1938
  log_lines.append(f"Prompt volume: {vol_label} pairs")
 
 
 
 
 
 
1939
  if quantization:
1940
  log_lines.append(f"Quantization: {quantization} (auto-detected for GPU fit)")
1941
  log_lines.append("")
 
2274
  _needs_reload = model is None or tokenizer is None
2275
  if not _needs_reload:
2276
  try:
2277
+ model_dev = next(model.parameters()).device
2278
+ if model_dev.type == "meta":
2279
  _needs_reload = True
2280
+ elif dev.is_gpu_available() and model_dev.type not in ("cuda", "mps"):
2281
+ model.to(dev.get_device())
2282
  except Exception:
2283
  _needs_reload = True
2284
 
 
2708
  _needs_reload = abliterated_model is None or tokenizer is None
2709
  if not _needs_reload:
2710
  try:
2711
+ model_dev = next(abliterated_model.parameters()).device
2712
+ if model_dev.type == "meta":
2713
  _needs_reload = True
2714
+ elif dev.is_gpu_available() and model_dev.type not in ("cuda", "mps"):
2715
+ abliterated_model.to(dev.get_device())
2716
  except Exception:
2717
  _needs_reload = True
2718
 
 
2845
  abl_device = next(abliterated_model.parameters()).device
2846
  abliterated_model.to("cpu")
2847
  gc.collect()
2848
+ dev.empty_cache()
 
2849
 
2850
  model_id = MODELS.get(model_name, model_name)
2851
  # Only trust remote code for known preset models, not arbitrary user-supplied IDs
 
2897
  # Free the original model
2898
  del original_model
2899
  gc.collect()
2900
+ dev.empty_cache()
 
2901
 
2902
  except Exception as e:
2903
  original_response = f"*Could not load original model for comparison: {e}*"
 
2906
  # Use torch.device("cuda") rather than the captured abl_device, since
2907
  # on ZeroGPU the original device reference may point to a stale context.
2908
  try:
2909
+ restore_device = torch.device(dev.get_device()) if dev.is_gpu_available() else abl_device
2910
  abliterated_model.to(restore_device)
2911
  except Exception:
2912
  pass # If GPU restore fails, model stays on CPU (still usable)
 
3024
 
3025
  # Cleanup between runs
3026
  gc.collect()
3027
+ dev.empty_cache()
 
3028
 
3029
  # Generate dose-response curve
3030
  gallery = None
 
3116
  return "\n".join(lines)
3117
 
3118
 
3119
+ # ---------------------------------------------------------------------------
3120
+ # Tournament
3121
+ # ---------------------------------------------------------------------------
3122
+
3123
+ @spaces.GPU(duration=300)
3124
+ def run_tourney(model_choice, dataset, quantization):
3125
+ """Run an elimination tournament across all abliteration methods.
3126
+
3127
+ On ZeroGPU Spaces the @spaces.GPU decorator allocates a GPU for the
3128
+ duration of the tournament (up to 5 minutes).
3129
+ """
3130
+ if not model_choice or not model_choice.strip():
3131
+ yield "**Error:** Select a model first.", "", ""
3132
+ return
3133
+
3134
+ from obliteratus.tourney import TourneyRunner, render_bracket
3135
+
3136
+ # Resolve display label → HuggingFace model ID
3137
+ model_id = model_choice.strip()
3138
+ if model_id in MODELS:
3139
+ model_id = MODELS[model_id]
3140
+
3141
+ quant = quantization if quantization != "none" else None
3142
+
3143
+ log_lines = []
3144
+
3145
+ def on_log(msg):
3146
+ log_lines.append(msg)
3147
+
3148
+ def on_round(rnd):
3149
+ pass # logged via on_log
3150
+
3151
+ dataset_key = get_source_key_from_label(dataset) if dataset else "builtin"
3152
+
3153
+ runner = TourneyRunner(
3154
+ model_name=model_id,
3155
+ hub_org=None,
3156
+ hub_repo=None,
3157
+ dataset_key=dataset_key,
3158
+ quantization=quant,
3159
+ on_log=on_log,
3160
+ on_round=on_round,
3161
+ )
3162
+
3163
+ # Run tournament in a background thread so we can yield progress
3164
+ import threading
3165
+ result_ref = [None]
3166
+ error_ref = [None]
3167
+
3168
+ def _run():
3169
+ try:
3170
+ result_ref[0] = runner.run()
3171
+ except Exception as e:
3172
+ error_ref[0] = e
3173
+
3174
+ thread = threading.Thread(target=_run, daemon=True)
3175
+ thread.start()
3176
+
3177
+ while thread.is_alive():
3178
+ yield (
3179
+ "**Tournament in progress...**",
3180
+ "",
3181
+ "\n".join(log_lines[-100:]),
3182
+ )
3183
+ time.sleep(1.0)
3184
+
3185
+ thread.join(timeout=10)
3186
+
3187
+ if error_ref[0]:
3188
+ yield (
3189
+ f"**Error:** {error_ref[0]}",
3190
+ "",
3191
+ "\n".join(log_lines),
3192
+ )
3193
+ return
3194
+
3195
+ result = result_ref[0]
3196
+ if result and result.winner:
3197
+ bracket_md = render_bracket(result)
3198
+ # Register winner in session models for Push to Hub tab
3199
+ if result.winner.output_dir:
3200
+ _ts = datetime.now().strftime("%H:%M")
3201
+ _short = model_id.split("/")[-1] if "/" in model_id else model_id
3202
+ _label = f"tourney winner ({result.winner.method}) on {_short} ({_ts})"
3203
+ with _lock:
3204
+ _session_models[_label] = {
3205
+ "model_id": model_id,
3206
+ "model_choice": model_choice,
3207
+ "method": result.winner.method,
3208
+ "dataset_key": dataset_key,
3209
+ "prompt_volume": 0,
3210
+ "output_dir": result.winner.output_dir,
3211
+ "source": "tourney",
3212
+ "tourney_score": result.winner.score,
3213
+ "tourney_metrics": result.winner.metrics,
3214
+ }
3215
+ yield (
3216
+ f"**Champion: `{result.winner.method}`** "
3217
+ f"(score: {result.winner.score:.4f})\n"
3218
+ f"Push it to HuggingFace Hub from the **Push to Hub** tab.",
3219
+ bracket_md,
3220
+ "\n".join(log_lines),
3221
+ )
3222
+ else:
3223
+ yield (
3224
+ "**Tournament complete** — no winner determined.",
3225
+ "",
3226
+ "\n".join(log_lines),
3227
+ )
3228
+
3229
+
3230
  # ---------------------------------------------------------------------------
3231
  # Export Research Artifacts
3232
  # ---------------------------------------------------------------------------
 
3787
  lines=5,
3788
  )
3789
 
3790
+ gr.Markdown(
3791
+ "*After obliterating, push your model to HuggingFace Hub from the **Push to Hub** tab.*",
3792
+ elem_classes=["hub-hint"],
3793
+ )
 
 
 
 
 
 
 
 
 
 
3794
 
3795
  # ── Advanced Settings (auto-populated from method preset) ────
3796
  _defaults = _get_preset_defaults("advanced (recommended)")
 
4422
  with gr.Tab("Tourney", id="tourney"):
4423
  gr.Markdown("""### March Madness Tournament
4424
  Pit **all abliteration methods** against each other in elimination rounds.
4425
+ The winner is saved locally — push it to HuggingFace Hub from the **Push to Hub** tab.
4426
 
4427
  **Round 1 — Qualifiers:** All methods, reduced prompts. Bottom half eliminated.
4428
  **Round 2 — Semifinals:** Survivors, full prompts. Bottom half eliminated.
4429
  **Round 3 — Finals:** Top contenders, maximum prompts. Champion crowned.
4430
  """)
4431
+ tourney_model_dd = gr.Dropdown(
4432
+ choices=list(MODELS.keys()),
4433
+ value="Alibaba (Qwen) / Qwen3-4B",
4434
+ label="Target Model",
4435
+ info="Select a model to tournament-abliterate",
4436
+ allow_custom_value=True,
4437
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4438
 
4439
  with gr.Accordion("Advanced Settings", open=False):
4440
  with gr.Row():
 
4463
  interactive=False,
4464
  )
4465
 
4466
+ tourney_btn.click(
4467
+ fn=run_tourney,
4468
+ inputs=[tourney_model_dd,
4469
  tourney_dataset_dd, tourney_quant_dd],
4470
  outputs=[tourney_status, tourney_bracket, tourney_log],
4471
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4472
 
4473
  # ── Tab 7: Export ────────────────────────────────────────────────���
4474
  with gr.Tab("Export", id="export"):
 
4491
  outputs=[export_file, export_status],
4492
  )
4493
 
4494
+ # ── Tab: Push to Hub ──────────────────────────────────────────────
4495
+ with gr.Tab("Push to Hub", id="push_hub"):
4496
+ gr.Markdown("""### Push to HuggingFace Hub
4497
+ Select any session model from your Obliterate, Benchmark, or Tourney runs,
4498
+ optionally apply a quick refinement pass, then push to HuggingFace Hub
4499
+ with the **-OBLITERATED** tag.
4500
+ """)
4501
+
4502
+ with gr.Row():
4503
+ with gr.Column(scale=2):
4504
+ push_session_dd = gr.Dropdown(
4505
+ choices=_get_session_model_choices(),
4506
+ label="Session Model",
4507
+ info="Pick a model from any tab's output",
4508
+ )
4509
+ push_refresh_btn = gr.Button("Refresh List", variant="secondary", size="sm")
4510
+ push_model_info = gr.Markdown("")
4511
+
4512
+ with gr.Column(scale=1):
4513
+ push_repo_id = gr.Textbox(
4514
+ label="Hub Repo ID",
4515
+ placeholder="auto-filled, or type your own",
4516
+ info="e.g. my-org/my-model-OBLITERATED",
4517
+ )
4518
+ push_token = gr.Textbox(
4519
+ label="HF Token (optional)",
4520
+ placeholder="hf_...",
4521
+ type="password",
4522
+ info="Leave blank to use HF_TOKEN env var or community token",
4523
+ )
4524
+ push_repo_warning = gr.Markdown("")
4525
+
4526
+ with gr.Accordion("Quick Refiner (optional)", open=False):
4527
+ gr.Markdown(
4528
+ "*Optionally apply extra refinement passes to your model before pushing. "
4529
+ "This re-runs the abliteration pipeline with adjusted regularization.*"
4530
+ )
4531
+ with gr.Row():
4532
+ push_refine_reg = gr.Slider(
4533
+ 0.0, 1.0, value=0.1, step=0.05,
4534
+ label="Regularization",
4535
+ info="Weight preservation (0 = full removal, 1 = no change)",
4536
+ )
4537
+ push_refine_passes = gr.Slider(
4538
+ 0, 3, value=0, step=1,
4539
+ label="Extra Refinement Passes",
4540
+ info="0 = skip refinement, 1-3 = apply additional passes",
4541
+ )
4542
+ push_refine_enabled = gr.Checkbox(
4543
+ label="Apply refinement before pushing",
4544
+ value=False,
4545
+ )
4546
+
4547
+ push_btn = gr.Button(
4548
+ "Push to Hub",
4549
+ variant="primary",
4550
+ size="lg",
4551
+ )
4552
+ push_status = gr.Markdown("")
4553
+ push_link = gr.Markdown("")
4554
+
4555
+ # -- Event wiring (inline since components are scoped to this tab) --
4556
+
4557
+ push_refresh_btn.click(
4558
+ fn=lambda: gr.update(choices=_get_session_model_choices()),
4559
+ outputs=[push_session_dd],
4560
+ )
4561
+
4562
+ push_session_dd.change(
4563
+ fn=lambda label: (_get_hub_session_info(label), _auto_hub_repo_id(label)),
4564
+ inputs=[push_session_dd],
4565
+ outputs=[push_model_info, push_repo_id],
4566
+ )
4567
+
4568
+ push_repo_id.change(
4569
+ fn=_validate_hub_repo,
4570
+ inputs=[push_repo_id],
4571
+ outputs=[push_repo_warning],
4572
+ )
4573
+
4574
+ push_btn.click(
4575
+ fn=push_session_to_hub,
4576
+ inputs=[push_session_dd, push_repo_id, push_token,
4577
+ push_refine_enabled, push_refine_reg, push_refine_passes],
4578
+ outputs=[push_status, push_link],
4579
+ )
4580
+
4581
+ # ── Tab: Leaderboard ────────────────────────────────────────────
4582
  with gr.Tab("Leaderboard", id="leaderboard"):
4583
  gr.Markdown("""### Community Leaderboard
4584
  All benchmark results from **every OBLITERATUS Space** (including duplicated copies) are
 
4804
  outputs=[prompt_vol_dd, dataset_info_md],
4805
  )
4806
 
 
 
 
 
 
 
4807
 
4808
  # Wire benchmark → Chat/A/B cross-tab dropdown updates
4809
  bench_btn.click(
 
4852
  # may not fire after generator teardown.
4853
  obliterate_btn.click(
4854
  fn=obliterate,
4855
+ inputs=[model_dd, method_dd, prompt_vol_dd, dataset_dd,
4856
  custom_harmful_tb, custom_harmless_tb] + _adv_controls,
4857
  outputs=[status_md, log_box, chat_status, session_model_dd, metrics_md, ab_session_model_dd],
4858
  ).then(
obliteratus/.DS_Store CHANGED
Binary files a/obliteratus/.DS_Store and b/obliteratus/.DS_Store differ
 
obliteratus/abliterate.py CHANGED
@@ -33,11 +33,12 @@ from typing import Any, Callable
33
  import torch
34
  import torch.nn as nn
35
 
 
 
36
  # Reduce CUDA memory fragmentation for large models. Must be set before any
37
  # CUDA allocations, so we do it at import time. This is the PyTorch-recommended
38
  # fix for "reserved but unallocated" memory issues.
39
- if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ and torch.cuda.is_available():
40
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
41
 
42
  from obliteratus.models.loader import ModelHandle, load_model # noqa: E402
43
  from obliteratus.strategies.utils import ( # noqa: E402
@@ -788,16 +789,8 @@ class AbliterationPipeline:
788
 
789
  @staticmethod
790
  def _free_gpu_memory():
791
- """Release unused GPU memory between pipeline stages."""
792
- import gc
793
- gc.collect()
794
- if torch.cuda.is_available():
795
- try:
796
- torch.cuda.empty_cache()
797
- except Exception:
798
- # CUDA may be in an error state after illegal memory access;
799
- # swallow so we don't cascade into every subsequent stage.
800
- pass
801
 
802
  @staticmethod
803
  def _get_model_device(model: nn.Module) -> torch.device:
@@ -1404,12 +1397,8 @@ class AbliterationPipeline:
1404
  max_length = self.max_seq_length
1405
  else:
1406
  max_length = 384 if collect_multi_pos else 256
1407
- free_gb = 0.0
1408
- if torch.cuda.is_available():
1409
- free_gb = sum(
1410
- torch.cuda.mem_get_info(i)[0] / (1024 ** 3)
1411
- for i in range(torch.cuda.device_count())
1412
- )
1413
  if self.max_seq_length is None and free_gb < 2.0:
1414
  max_length = 64
1415
  self.log(f" Low GPU memory ({free_gb:.1f} GB free), using max_length={max_length}")
@@ -1999,22 +1988,22 @@ class AbliterationPipeline:
1999
  # Memory-aware cap: SAE encoder+decoder use
2000
  # 2 * hidden * (expansion * hidden) * 4 bytes
2001
  sae_mem_mb = 2 * hidden_dim * (sae_expansion * hidden_dim) * 4 / 1e6
2002
- if torch.cuda.is_available():
2003
  try:
2004
- free_mb = torch.cuda.mem_get_info()[0] / 1e6
2005
  # Leave 512 MB headroom for other ops
2006
  while sae_mem_mb > (free_mb - 512) and sae_expansion > 1:
2007
  sae_expansion //= 2
2008
  sae_mem_mb = 2 * hidden_dim * (sae_expansion * hidden_dim) * 4 / 1e6
2009
  except Exception:
2010
  pass # Fallback to hidden_dim-based heuristic
2011
- # Use GPU when enough headroom exists (SAE is small relative to model)
2012
  sae_device = "cpu"
2013
- if torch.cuda.is_available():
2014
  try:
2015
- sae_free_mb = torch.cuda.mem_get_info()[0] / 1e6
2016
  if sae_free_mb > sae_mem_mb + 1024:
2017
- sae_device = "cuda"
2018
  except Exception:
2019
  pass
2020
  sae = train_sae(
@@ -4155,7 +4144,8 @@ class AbliterationPipeline:
4155
  continue
4156
  original_norm = saved_norms[param_name]
4157
  if original_norm > 0:
4158
- data = param.data.float() if not param.data.is_floating_point() else param.data
 
4159
  new_norm = data.norm().item()
4160
  if math.isnan(new_norm) or math.isinf(new_norm) or new_norm == 0:
4161
  continue # Skip — weight is degenerate after projection
@@ -4165,7 +4155,12 @@ class AbliterationPipeline:
4165
  # layers. Uncapped amplification destroys coherence.
4166
  if ratio > _MAX_NORM_RATIO:
4167
  ratio = _MAX_NORM_RATIO
4168
- param.data.mul_(ratio)
 
 
 
 
 
4169
 
4170
  @staticmethod
4171
  def _project_out_advanced(
@@ -5371,16 +5366,19 @@ class AbliterationPipeline:
5371
  unique_ratio = len(set(words)) / len(words)
5372
  if unique_ratio > 0.2:
5373
  coherent_count += 1
5374
- except torch.cuda.OutOfMemoryError:
5375
- self._free_gpu_memory()
5376
- self.log(" Skipping generation tests (CUDA out of memory — model too large for KV cache)")
5377
- generation_failed = True
5378
  except (RuntimeError, Exception) as e:
5379
- err_msg = str(e)
5380
- if "CUDA" in err_msg or "illegal" in err_msg.lower():
5381
  self._free_gpu_memory()
5382
- self.log(f" Skipping generation tests (CUDA error: {err_msg[:120]})")
5383
  generation_failed = True
 
 
 
 
 
 
 
 
5384
  else:
5385
  raise
5386
 
@@ -5523,18 +5521,21 @@ class AbliterationPipeline:
5523
 
5524
  del inputs, outputs
5525
  self._free_gpu_memory()
5526
- except torch.cuda.OutOfMemoryError:
5527
- self._free_gpu_memory()
5528
- self.log(f" [batch {batch_start+1}-{batch_end}] CUDA OOM — stopping")
5529
- self.log(" Skipping remaining refusal tests (CUDA out of memory)")
5530
- oom_break = True
5531
  except (RuntimeError, Exception) as e:
5532
- err_msg = str(e)
5533
- if "CUDA" in err_msg or "illegal" in err_msg.lower():
5534
  self._free_gpu_memory()
5535
- self.log(f" [batch {batch_start+1}-{batch_end}] CUDA error — stopping")
5536
- self.log(f" Skipping remaining refusal tests (CUDA error: {err_msg[:120]})")
5537
  oom_break = True
 
 
 
 
 
 
 
 
 
5538
  else:
5539
  raise
5540
 
 
33
  import torch
34
  import torch.nn as nn
35
 
36
+ from obliteratus import device as dev # noqa: E402 — must import before CUDA setup
37
+
38
  # Reduce CUDA memory fragmentation for large models. Must be set before any
39
  # CUDA allocations, so we do it at import time. This is the PyTorch-recommended
40
  # fix for "reserved but unallocated" memory issues.
41
+ dev.configure_cuda_alloc()
 
42
 
43
  from obliteratus.models.loader import ModelHandle, load_model # noqa: E402
44
  from obliteratus.strategies.utils import ( # noqa: E402
 
789
 
790
  @staticmethod
791
  def _free_gpu_memory():
792
+ """Release unused GPU/accelerator memory between pipeline stages."""
793
+ dev.free_gpu_memory()
 
 
 
 
 
 
 
 
794
 
795
  @staticmethod
796
  def _get_model_device(model: nn.Module) -> torch.device:
 
1397
  max_length = self.max_seq_length
1398
  else:
1399
  max_length = 384 if collect_multi_pos else 256
1400
+ free_gb = dev.get_total_free_gb()
1401
+ if dev.is_gpu_available():
 
 
 
 
1402
  if self.max_seq_length is None and free_gb < 2.0:
1403
  max_length = 64
1404
  self.log(f" Low GPU memory ({free_gb:.1f} GB free), using max_length={max_length}")
 
1988
  # Memory-aware cap: SAE encoder+decoder use
1989
  # 2 * hidden * (expansion * hidden) * 4 bytes
1990
  sae_mem_mb = 2 * hidden_dim * (sae_expansion * hidden_dim) * 4 / 1e6
1991
+ if dev.is_gpu_available():
1992
  try:
1993
+ free_mb = dev.get_total_free_gb() * 1024
1994
  # Leave 512 MB headroom for other ops
1995
  while sae_mem_mb > (free_mb - 512) and sae_expansion > 1:
1996
  sae_expansion //= 2
1997
  sae_mem_mb = 2 * hidden_dim * (sae_expansion * hidden_dim) * 4 / 1e6
1998
  except Exception:
1999
  pass # Fallback to hidden_dim-based heuristic
2000
+ # Use GPU/MPS when enough headroom exists (SAE is small relative to model)
2001
  sae_device = "cpu"
2002
+ if dev.is_gpu_available():
2003
  try:
2004
+ sae_free_mb = dev.get_total_free_gb() * 1024
2005
  if sae_free_mb > sae_mem_mb + 1024:
2006
+ sae_device = dev.get_device()
2007
  except Exception:
2008
  pass
2009
  sae = train_sae(
 
4144
  continue
4145
  original_norm = saved_norms[param_name]
4146
  if original_norm > 0:
4147
+ needs_cast = not param.data.is_floating_point()
4148
+ data = param.data.float() if needs_cast else param.data
4149
  new_norm = data.norm().item()
4150
  if math.isnan(new_norm) or math.isinf(new_norm) or new_norm == 0:
4151
  continue # Skip — weight is degenerate after projection
 
4155
  # layers. Uncapped amplification destroys coherence.
4156
  if ratio > _MAX_NORM_RATIO:
4157
  ratio = _MAX_NORM_RATIO
4158
+ if needs_cast:
4159
+ # Non-float dtypes (e.g. uint8) can't mul_ by a float
4160
+ # scalar in-place — rescale in float then cast back.
4161
+ param.data.copy_(data.mul_(ratio).to(param.data.dtype))
4162
+ else:
4163
+ param.data.mul_(ratio)
4164
 
4165
  @staticmethod
4166
  def _project_out_advanced(
 
5366
  unique_ratio = len(set(words)) / len(words)
5367
  if unique_ratio > 0.2:
5368
  coherent_count += 1
 
 
 
 
5369
  except (RuntimeError, Exception) as e:
5370
+ if dev.is_oom_error(e):
 
5371
  self._free_gpu_memory()
5372
+ self.log(" Skipping generation tests (out of memory — model too large for KV cache)")
5373
  generation_failed = True
5374
+ elif isinstance(e, RuntimeError):
5375
+ err_msg = str(e)
5376
+ if "CUDA" in err_msg or "MPS" in err_msg or "illegal" in err_msg.lower():
5377
+ self._free_gpu_memory()
5378
+ self.log(f" Skipping generation tests (device error: {err_msg[:120]})")
5379
+ generation_failed = True
5380
+ else:
5381
+ raise
5382
  else:
5383
  raise
5384
 
 
5521
 
5522
  del inputs, outputs
5523
  self._free_gpu_memory()
 
 
 
 
 
5524
  except (RuntimeError, Exception) as e:
5525
+ if dev.is_oom_error(e):
 
5526
  self._free_gpu_memory()
5527
+ self.log(f" [batch {batch_start+1}-{batch_end}] OOM — stopping")
5528
+ self.log(" Skipping remaining refusal tests (out of memory)")
5529
  oom_break = True
5530
+ elif isinstance(e, RuntimeError):
5531
+ err_msg = str(e)
5532
+ if "CUDA" in err_msg or "MPS" in err_msg or "illegal" in err_msg.lower():
5533
+ self._free_gpu_memory()
5534
+ self.log(f" [batch {batch_start+1}-{batch_end}] device error — stopping")
5535
+ self.log(f" Skipping remaining refusal tests (device error: {err_msg[:120]})")
5536
+ oom_break = True
5537
+ else:
5538
+ raise
5539
  else:
5540
  raise
5541
 
obliteratus/analysis/sae_abliteration.py CHANGED
@@ -39,6 +39,7 @@ from dataclasses import dataclass
39
 
40
  import torch
41
  import torch.nn as nn
 
42
 
43
 
44
  @dataclass
@@ -120,11 +121,11 @@ def _auto_detect_device(device: str | None = None) -> str:
120
  """
121
  if device is not None and device not in ("auto",):
122
  return device
123
- if torch.cuda.is_available():
124
  try:
125
- free_mb = torch.cuda.mem_get_info()[0] / 1e6
126
  if free_mb > 512:
127
- return "cuda"
128
  except Exception:
129
  pass
130
  return "cpu"
 
39
 
40
  import torch
41
  import torch.nn as nn
42
+ from obliteratus import device as dev
43
 
44
 
45
  @dataclass
 
121
  """
122
  if device is not None and device not in ("auto",):
123
  return device
124
+ if dev.is_gpu_available():
125
  try:
126
+ free_mb = dev.get_total_free_gb() * 1024
127
  if free_mb > 512:
128
+ return dev.get_device()
129
  except Exception:
130
  pass
131
  return "cpu"
obliteratus/device.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unified device abstraction for CUDA, MPS (Apple Silicon), and CPU.
2
+
3
+ All device-specific queries (availability, memory, cache management) go through
4
+ this module so the rest of the codebase never calls ``torch.cuda.*`` directly.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import gc
10
+ import logging
11
+ import os
12
+ import platform
13
+ from dataclasses import dataclass
14
+
15
+ import torch
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ # ---------------------------------------------------------------------------
20
+ # Device detection
21
+ # ---------------------------------------------------------------------------
22
+
23
+ def is_cuda() -> bool:
24
+ """True when at least one NVIDIA CUDA GPU is visible."""
25
+ return torch.cuda.is_available()
26
+
27
+
28
+ def is_mps() -> bool:
29
+ """True when Apple Metal Performance Shaders backend is usable."""
30
+ return (
31
+ hasattr(torch.backends, "mps")
32
+ and torch.backends.mps.is_available()
33
+ and torch.backends.mps.is_built()
34
+ )
35
+
36
+
37
+ def is_gpu_available() -> bool:
38
+ """True if *any* GPU backend (CUDA or MPS) is available."""
39
+ return is_cuda() or is_mps()
40
+
41
+
42
+ def get_device(preference: str = "auto") -> str:
43
+ """Resolve a device string.
44
+
45
+ Parameters
46
+ ----------
47
+ preference : str
48
+ ``"auto"`` picks the best GPU, ``"cuda"``/``"mps"``/``"cpu"`` forces.
49
+
50
+ Returns
51
+ -------
52
+ str
53
+ A PyTorch device string (``"cuda"``, ``"mps"``, or ``"cpu"``).
54
+ """
55
+ if preference == "auto":
56
+ if is_cuda():
57
+ return "cuda"
58
+ if is_mps():
59
+ return "mps"
60
+ return "cpu"
61
+ return preference
62
+
63
+
64
+ def get_device_name() -> str:
65
+ """Human-readable name of the current accelerator."""
66
+ if is_cuda():
67
+ return torch.cuda.get_device_name(0)
68
+ if is_mps():
69
+ # Apple doesn't expose a per-chip name via MPS; use platform info.
70
+ chip = platform.processor() or "Apple Silicon"
71
+ return f"Apple {chip} (MPS)"
72
+ return "CPU"
73
+
74
+
75
+ def device_count() -> int:
76
+ """Number of accelerator devices (GPUs or MPS slots)."""
77
+ if is_cuda():
78
+ return torch.cuda.device_count()
79
+ if is_mps():
80
+ return 1 # MPS always exposes a single unified device
81
+ return 0
82
+
83
+
84
+ # ---------------------------------------------------------------------------
85
+ # Memory information
86
+ # ---------------------------------------------------------------------------
87
+
88
+ @dataclass
89
+ class MemoryInfo:
90
+ """Snapshot of accelerator memory (in GB)."""
91
+
92
+ used_gb: float = 0.0
93
+ reserved_gb: float = 0.0
94
+ total_gb: float = 0.0
95
+ free_gb: float = 0.0
96
+ device_name: str = "CPU"
97
+
98
+
99
+ def _system_memory_gb() -> tuple[float, float]:
100
+ """Return (total_gb, available_gb) of system RAM."""
101
+ try:
102
+ import psutil
103
+ vm = psutil.virtual_memory()
104
+ return vm.total / 1024 ** 3, vm.available / 1024 ** 3
105
+ except ImportError:
106
+ pass
107
+ try:
108
+ total = os.sysconf("SC_PHYS_PAGES") * os.sysconf("SC_PAGE_SIZE") / 1024 ** 3
109
+ # Rough estimate: assume 60 % available if we can't query
110
+ return total, total * 0.6
111
+ except (AttributeError, ValueError):
112
+ return 16.0, 8.0 # conservative fallback
113
+
114
+
115
+ def get_memory_info(device_index: int = 0) -> MemoryInfo:
116
+ """Query memory for the given accelerator (or system RAM for MPS/CPU)."""
117
+ name = get_device_name()
118
+
119
+ if is_cuda():
120
+ try:
121
+ free, total = torch.cuda.mem_get_info(device_index)
122
+ used = torch.cuda.memory_allocated(device_index)
123
+ reserved = torch.cuda.memory_reserved(device_index)
124
+ total_gb = total / 1024 ** 3
125
+ return MemoryInfo(
126
+ used_gb=used / 1024 ** 3,
127
+ reserved_gb=reserved / 1024 ** 3,
128
+ total_gb=total_gb,
129
+ free_gb=free / 1024 ** 3,
130
+ device_name=name,
131
+ )
132
+ except Exception:
133
+ props = torch.cuda.get_device_properties(device_index)
134
+ total_gb = props.total_memory / 1024 ** 3
135
+ return MemoryInfo(total_gb=total_gb, free_gb=total_gb, device_name=name)
136
+
137
+ if is_mps():
138
+ # MPS uses unified memory — report system RAM as a proxy.
139
+ total, avail = _system_memory_gb()
140
+ # Apple's unified memory is shared with the OS, so usable fraction
141
+ # is typically ~65-75 % of total.
142
+ usable = total * 0.70
143
+ return MemoryInfo(
144
+ used_gb=max(usable - avail, 0.0),
145
+ reserved_gb=0.0,
146
+ total_gb=usable,
147
+ free_gb=min(avail, usable),
148
+ device_name=name,
149
+ )
150
+
151
+ # CPU-only
152
+ total, avail = _system_memory_gb()
153
+ return MemoryInfo(total_gb=total, free_gb=avail, device_name=name)
154
+
155
+
156
+ def get_total_free_gb() -> float:
157
+ """Sum of free memory across all accelerator devices, in GB."""
158
+ if is_cuda():
159
+ total_free = 0.0
160
+ for i in range(torch.cuda.device_count()):
161
+ try:
162
+ free, _ = torch.cuda.mem_get_info(i)
163
+ total_free += free / 1024 ** 3
164
+ except Exception:
165
+ props = torch.cuda.get_device_properties(i)
166
+ total_free += props.total_memory / 1024 ** 3
167
+ return total_free
168
+ if is_mps():
169
+ _, avail = _system_memory_gb()
170
+ return avail * 0.70 # usable fraction
171
+ return 0.0
172
+
173
+
174
+ # ---------------------------------------------------------------------------
175
+ # Cache / memory management
176
+ # ---------------------------------------------------------------------------
177
+
178
+ def empty_cache() -> None:
179
+ """Release cached allocations on the current accelerator."""
180
+ if is_cuda():
181
+ torch.cuda.empty_cache()
182
+ elif is_mps():
183
+ # torch.mps.empty_cache() available since PyTorch 2.1
184
+ if hasattr(torch.mps, "empty_cache"):
185
+ torch.mps.empty_cache()
186
+
187
+
188
+ def free_gpu_memory() -> None:
189
+ """Aggressive memory cleanup: GC + accelerator cache flush."""
190
+ gc.collect()
191
+ if is_cuda():
192
+ try:
193
+ torch.cuda.empty_cache()
194
+ except Exception:
195
+ try:
196
+ torch.cuda.synchronize()
197
+ except Exception:
198
+ pass
199
+ try:
200
+ torch.cuda.reset_peak_memory_stats()
201
+ except Exception:
202
+ pass
203
+ elif is_mps():
204
+ if hasattr(torch.mps, "empty_cache"):
205
+ try:
206
+ torch.mps.empty_cache()
207
+ except Exception:
208
+ pass
209
+ if hasattr(torch.mps, "synchronize"):
210
+ try:
211
+ torch.mps.synchronize()
212
+ except Exception:
213
+ pass
214
+
215
+
216
+ def set_seed_all(seed: int) -> None:
217
+ """Set random seed on all available accelerators."""
218
+ torch.manual_seed(seed)
219
+ if is_cuda():
220
+ torch.cuda.manual_seed_all(seed)
221
+ # MPS shares the CPU random state — no separate seed call needed.
222
+
223
+
224
+ # ---------------------------------------------------------------------------
225
+ # Dtype helpers
226
+ # ---------------------------------------------------------------------------
227
+
228
+ def default_dtype(device: str | None = None) -> torch.dtype:
229
+ """Sensible default dtype for the given device."""
230
+ dev = device or get_device()
231
+ if dev == "cpu":
232
+ return torch.float32
233
+ return torch.float16
234
+
235
+
236
+ def supports_bfloat16(device: str | None = None) -> bool:
237
+ """Whether *bfloat16* is natively supported on the target device."""
238
+ dev = device or get_device()
239
+ if dev.startswith("cuda"):
240
+ if is_cuda():
241
+ major, _ = torch.cuda.get_device_capability(0)
242
+ return major >= 8 # Ampere+
243
+ return False
244
+ if dev == "mps":
245
+ # MPS added bfloat16 support in PyTorch 2.3+
246
+ return hasattr(torch, "__version__") and tuple(
247
+ int(x) for x in torch.__version__.split(".")[:2]
248
+ ) >= (2, 3)
249
+ return True # CPU supports bfloat16 on most modern platforms
250
+
251
+
252
+ def supports_float64(device: str | None = None) -> bool:
253
+ """Whether *float64* is supported (MPS does NOT support it)."""
254
+ dev = device or get_device()
255
+ return dev != "mps"
256
+
257
+
258
+ def safe_svd_dtype(tensor: torch.Tensor) -> torch.dtype:
259
+ """Return a dtype safe for SVD on the tensor's device.
260
+
261
+ MPS does not support float64, so we cap at float32.
262
+ """
263
+ if tensor.device.type == "mps":
264
+ return torch.float32
265
+ return torch.float64 if tensor.dtype in (torch.float64, torch.float32) else torch.float32
266
+
267
+
268
+ # ---------------------------------------------------------------------------
269
+ # OOM exception matching
270
+ # ---------------------------------------------------------------------------
271
+
272
+ def is_oom_error(exc: BaseException) -> bool:
273
+ """Return True if *exc* is an out-of-memory error on any backend."""
274
+ if isinstance(exc, torch.cuda.OutOfMemoryError):
275
+ return True
276
+ # MPS raises a generic RuntimeError containing "out of memory"
277
+ if isinstance(exc, RuntimeError) and "out of memory" in str(exc).lower():
278
+ return True
279
+ return False
280
+
281
+
282
+ # ---------------------------------------------------------------------------
283
+ # Quantization compatibility
284
+ # ---------------------------------------------------------------------------
285
+
286
+ def supports_bitsandbytes(device: str | None = None) -> bool:
287
+ """BitsAndBytes requires NVIDIA CUDA — check that."""
288
+ dev = device or get_device()
289
+ return dev.startswith("cuda")
290
+
291
+
292
+ def supports_device_map_auto(device: str | None = None) -> bool:
293
+ """Accelerate's device_map='auto' is only reliable on CUDA."""
294
+ dev = device or get_device()
295
+ return dev.startswith("cuda")
296
+
297
+
298
+ # ---------------------------------------------------------------------------
299
+ # CUDA env setup (called once at import time of abliterate.py)
300
+ # ---------------------------------------------------------------------------
301
+
302
+ def configure_cuda_alloc() -> None:
303
+ """Set expandable_segments for CUDA if available."""
304
+ if is_cuda() and "PYTORCH_CUDA_ALLOC_CONF" not in os.environ:
305
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
obliteratus/evaluation/benchmarks.py CHANGED
@@ -26,6 +26,7 @@ import re
26
  from dataclasses import dataclass, field
27
 
28
  import torch
 
29
 
30
 
31
  @dataclass
@@ -261,8 +262,7 @@ class BenchmarkRunner:
261
  ("math_reasoning", self.run_math_reasoning_probe)]:
262
  results[name] = fn()
263
  # Free KV caches between probes to prevent OOM on tight GPUs
264
- if torch.cuda.is_available():
265
- torch.cuda.empty_cache()
266
  return results
267
 
268
  def _answer_mcq(self, question: str, choices: list[str]) -> int:
 
26
  from dataclasses import dataclass, field
27
 
28
  import torch
29
+ from obliteratus import device as dev
30
 
31
 
32
  @dataclass
 
262
  ("math_reasoning", self.run_math_reasoning_probe)]:
263
  results[name] = fn()
264
  # Free KV caches between probes to prevent OOM on tight GPUs
265
+ dev.empty_cache()
 
266
  return results
267
 
268
  def _answer_mcq(self, question: str, choices: list[str]) -> int:
obliteratus/evaluation/heretic_eval.py CHANGED
@@ -32,6 +32,7 @@ from typing import TYPE_CHECKING
32
 
33
  import torch
34
  import torch.nn.functional as F
 
35
 
36
  if TYPE_CHECKING:
37
  from collections.abc import Callable
@@ -363,8 +364,7 @@ def unload_harmbench_classifier() -> None:
363
  model, tokenizer = _HARMBENCH_CLASSIFIER
364
  del model, tokenizer
365
  _HARMBENCH_CLASSIFIER = None
366
- if torch.cuda.is_available():
367
- torch.cuda.empty_cache()
368
  logger.info("HarmBench classifier unloaded")
369
 
370
 
@@ -432,8 +432,7 @@ def harmbench_asr(
432
 
433
  # Free memory between batches
434
  del inputs, outputs
435
- if torch.cuda.is_available():
436
- torch.cuda.empty_cache()
437
 
438
  n_successful = sum(per_item)
439
  return {
@@ -536,8 +535,7 @@ def first_token_kl_on_prompts(
536
  kl_values.extend(kl.cpu().tolist())
537
 
538
  del inputs_orig, inputs_mod, logits_orig, logits_mod, first_orig, first_mod
539
- if torch.cuda.is_available():
540
- torch.cuda.empty_cache()
541
 
542
  mean_kl = statistics.mean(kl_values) if kl_values else 0.0
543
  std_kl = statistics.stdev(kl_values) if len(kl_values) > 1 else 0.0
@@ -1098,8 +1096,8 @@ def run_full_heretic_eval(
1098
  completions.append("")
1099
 
1100
  del inputs
1101
- if i % 25 == 0 and torch.cuda.is_available():
1102
- torch.cuda.empty_cache()
1103
 
1104
  log(f"Generated {len(completions)} completions")
1105
 
 
32
 
33
  import torch
34
  import torch.nn.functional as F
35
+ from obliteratus import device as dev
36
 
37
  if TYPE_CHECKING:
38
  from collections.abc import Callable
 
364
  model, tokenizer = _HARMBENCH_CLASSIFIER
365
  del model, tokenizer
366
  _HARMBENCH_CLASSIFIER = None
367
+ dev.empty_cache()
 
368
  logger.info("HarmBench classifier unloaded")
369
 
370
 
 
432
 
433
  # Free memory between batches
434
  del inputs, outputs
435
+ dev.empty_cache()
 
436
 
437
  n_successful = sum(per_item)
438
  return {
 
535
  kl_values.extend(kl.cpu().tolist())
536
 
537
  del inputs_orig, inputs_mod, logits_orig, logits_mod, first_orig, first_mod
538
+ dev.empty_cache()
 
539
 
540
  mean_kl = statistics.mean(kl_values) if kl_values else 0.0
541
  std_kl = statistics.stdev(kl_values) if len(kl_values) > 1 else 0.0
 
1096
  completions.append("")
1097
 
1098
  del inputs
1099
+ if i % 25 == 0:
1100
+ dev.empty_cache()
1101
 
1102
  log(f"Generated {len(completions)} completions")
1103
 
obliteratus/interactive.py CHANGED
@@ -21,9 +21,10 @@ console = Console()
21
  def _detect_compute_tier() -> str:
22
  """Auto-detect the best compute tier based on available hardware."""
23
  try:
24
- import torch
25
 
26
- if torch.cuda.is_available():
 
27
  vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
28
  if vram_gb >= 20:
29
  return "large"
@@ -31,8 +32,13 @@ def _detect_compute_tier() -> str:
31
  return "medium"
32
  else:
33
  return "small"
34
- elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
35
- return "small" # Apple Silicon conservative estimate
 
 
 
 
 
36
  except ImportError:
37
  pass
38
  return "tiny" # CPU only
@@ -237,12 +243,11 @@ def run_interactive():
237
  dtype = model_preset.recommended_dtype
238
  quantization = None
239
  try:
240
- import torch
241
 
242
- if torch.cuda.is_available():
243
- device = "auto"
244
- elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
245
- device = "mps"
246
  except ImportError:
247
  pass
248
 
 
21
  def _detect_compute_tier() -> str:
22
  """Auto-detect the best compute tier based on available hardware."""
23
  try:
24
+ from obliteratus import device as dev
25
 
26
+ if dev.is_cuda():
27
+ import torch
28
  vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
29
  if vram_gb >= 20:
30
  return "large"
 
32
  return "medium"
33
  else:
34
  return "small"
35
+ elif dev.is_mps():
36
+ # Apple Silicon with unified memory — estimate from system RAM
37
+ mem = dev.get_memory_info()
38
+ if mem.total_gb >= 24:
39
+ return "medium" # M1 Pro/Max/Ultra, M2 Pro/Max/Ultra, M3 Pro/Max
40
+ else:
41
+ return "small" # M1/M2/M3 base (8-16 GB)
42
  except ImportError:
43
  pass
44
  return "tiny" # CPU only
 
243
  dtype = model_preset.recommended_dtype
244
  quantization = None
245
  try:
246
+ from obliteratus import device as _dev
247
 
248
+ resolved = _dev.get_device()
249
+ if resolved != "cpu":
250
+ device = resolved if resolved == "mps" else "auto"
 
251
  except ImportError:
252
  pass
253
 
obliteratus/mlx_backend.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Optional MLX backend for Apple Silicon native inference and weight editing.
2
+
3
+ MLX is Apple's array framework that runs natively on the Apple Neural Engine
4
+ and Metal GPU. When available, it provides significantly faster inference and
5
+ weight manipulation than PyTorch's MPS backend on Apple hardware.
6
+
7
+ This module is entirely optional — if ``mlx`` / ``mlx-lm`` are not installed,
8
+ ``MLX_AVAILABLE`` is ``False`` and all public functions raise ``RuntimeError``.
9
+
10
+ Install with::
11
+
12
+ pip install mlx>=0.22 mlx-lm>=0.20
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import logging
18
+ from pathlib import Path
19
+ from typing import Any, Callable
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # ---------------------------------------------------------------------------
24
+ # Availability check
25
+ # ---------------------------------------------------------------------------
26
+
27
# Module handles for the optional MLX stack; populated only when every
# required package imports cleanly, so a partial install leaves all None.
MLX_AVAILABLE = False
_mx = None       # mlx.core module
_mlx_lm = None   # mlx_lm module
_mlx_nn = None   # mlx.nn module

try:
    import mlx.core as _mx_core  # type: ignore[import-untyped]
    import mlx.nn as _mlx_nn_mod  # type: ignore[import-untyped]
    import mlx_lm  # type: ignore[import-untyped]
except ImportError:
    pass
else:
    _mx, _mlx_nn, _mlx_lm = _mx_core, _mlx_nn_mod, mlx_lm
    MLX_AVAILABLE = True
    logger.info(
        "MLX backend available (mlx %s)",
        getattr(_mx, "__version__", "?"),
    )
44
+
45
+
46
def _require_mlx() -> None:
    """Raise ``RuntimeError`` unless the optional MLX stack imported successfully."""
    if MLX_AVAILABLE:
        return
    raise RuntimeError(
        "MLX backend is not available. Install with: pip install mlx>=0.22 mlx-lm>=0.20"
    )
51
+
52
+
53
+ # ---------------------------------------------------------------------------
54
+ # Model loading
55
+ # ---------------------------------------------------------------------------
56
+
57
class MLXModelHandle:
    """Lightweight wrapper around an MLX-loaded model + tokenizer.

    Bundles the model object, its tokenizer, and the HuggingFace identifier
    it was loaded from, so downstream helpers take a single argument.
    """

    def __init__(self, model: Any, tokenizer: Any, model_name: str):
        self.model = model            # the mlx-lm model object
        self.tokenizer = tokenizer    # matching tokenizer
        self.model_name = model_name  # HF identifier it was loaded from

    @property
    def config(self) -> Any:
        """The wrapped model's ``config`` attribute, or ``None`` if it has none."""
        return getattr(self.model, "config", None)
68
+
69
+
70
def load_model(
    model_name: str,
    dtype: str = "float16",
) -> MLXModelHandle:
    """Load a HuggingFace model via ``mlx-lm`` for Apple-native execution.

    Parameters
    ----------
    model_name : str
        HuggingFace model identifier (e.g. ``"meta-llama/Llama-3.2-3B-Instruct"``).
    dtype : str
        One of ``"float16"``, ``"bfloat16"``, ``"float32"``.
        NOTE: ``mlx_lm.load`` takes no dtype argument — the checkpoint's
        stored precision is used. A warning is logged when a non-default
        dtype is requested so the mismatch is visible instead of silent.

    Returns
    -------
    MLXModelHandle
        Wrapper with ``.model`` and ``.tokenizer`` attributes.
    """
    _require_mlx()

    from mlx_lm import load  # type: ignore[import-untyped]

    logger.info("Loading %s via MLX (dtype=%s)", model_name, dtype)
    if dtype != "float16":
        # Bug fix: `dtype` used to be silently ignored; make that explicit.
        logger.warning(
            "MLX backend cannot apply the requested dtype %r; the "
            "checkpoint's stored precision is used instead.",
            dtype,
        )
    model, tokenizer = load(model_name)

    return MLXModelHandle(model=model, tokenizer=tokenizer, model_name=model_name)
96
+
97
+
98
+ # ---------------------------------------------------------------------------
99
+ # Inference
100
+ # ---------------------------------------------------------------------------
101
+
102
def generate(
    handle: MLXModelHandle,
    prompt: str,
    max_tokens: int = 256,
    temperature: float = 0.7,
    top_p: float = 0.9,
    repetition_penalty: float | None = None,
) -> str:
    """Generate a text completion with the MLX model.

    Parameters
    ----------
    handle : MLXModelHandle
        A loaded MLX model handle.
    prompt : str
        The input prompt string.
    max_tokens : int
        Maximum number of tokens to generate.
    temperature : float
        Sampling temperature (passed to mlx-lm as ``temp``).
    top_p : float
        Nucleus sampling threshold.
    repetition_penalty : float or None
        Repetition penalty; ``None`` means the option is not forwarded.

    Returns
    -------
    str
        Generated text completion.
    """
    _require_mlx()

    from mlx_lm import generate as mlx_generate  # type: ignore[import-untyped]

    # Collect sampling options; keys mirror mlx-lm's keyword names.
    sampling: dict[str, Any] = dict(
        max_tokens=max_tokens,
        temp=temperature,
        top_p=top_p,
    )
    if repetition_penalty is not None:
        sampling["repetition_penalty"] = repetition_penalty

    return mlx_generate(handle.model, handle.tokenizer, prompt=prompt, **sampling)
150
+
151
+
152
+ # ---------------------------------------------------------------------------
153
+ # Activation extraction
154
+ # ---------------------------------------------------------------------------
155
+
156
def get_activations(
    handle: MLXModelHandle,
    prompts: list[str],
    layer_indices: list[int],
    max_length: int = 256,
) -> dict[int, list[Any]]:
    """Extract last-token hidden states from specified transformer layers.

    Runs each prompt through the embedding layer and every transformer
    block, capturing the final-position hidden state at each requested
    layer index.

    Parameters
    ----------
    handle : MLXModelHandle
        Loaded model.
    prompts : list[str]
        Input prompts to probe.
    layer_indices : list[int]
        Which transformer layers to capture.
    max_length : int
        Maximum sequence length; longer tokenizations are truncated.

    Returns
    -------
    dict[int, list[mlx.core.array]]
        Layer index -> list of ``(hidden_dim,)`` last-token activations,
        one per prompt, in prompt order.
    """
    _require_mlx()
    import mlx.core as mx  # type: ignore[import-untyped]

    model = handle.model
    tokenizer = handle.tokenizer

    # Consistency fix: reuse the shared locator instead of duplicating its
    # search logic, but keep this function's more descriptive error message.
    try:
        layers = _get_layers(model)
    except RuntimeError:
        raise RuntimeError(
            "Cannot locate transformer layers in the MLX model. "
            "Supported architectures: LLaMA, GPT-2, GPT-NeoX."
        ) from None

    # Resolve the embedding callable once — it is identical for every prompt
    # (previously re-resolved inside the loop).
    if hasattr(model, "model"):
        embed_module = model.model  # LLaMA-style: model.model.embed_tokens
    elif hasattr(model, "transformer"):
        embed_module = model.transformer
    else:
        embed_module = model

    if hasattr(embed_module, "embed_tokens"):
        embed = embed_module.embed_tokens
    elif hasattr(embed_module, "wte"):
        embed = embed_module.wte
    else:
        raise RuntimeError("Cannot find embedding layer in MLX model")

    activations: dict[int, list[Any]] = {idx: [] for idx in layer_indices}
    target_set = set(layer_indices)

    for prompt in prompts:
        tokens = tokenizer.encode(prompt)[:max_length]
        h = embed(mx.array([tokens]))

        # NOTE(review): blocks are called with the hidden state only; MLX
        # decoder layers that additionally require a mask/cache argument may
        # need more — confirm against the supported architectures.
        for i, layer in enumerate(layers):
            h = layer(h)
            if isinstance(h, tuple):
                # Some layers return (hidden, attention) — take the first.
                h = h[0]

            if i in target_set:
                last_hidden = h[0, -1, :]  # last-token hidden state
                mx.eval(last_hidden)  # force lazy-graph evaluation
                activations[i].append(last_hidden)

    return activations
247
+
248
+
249
+ # ---------------------------------------------------------------------------
250
+ # Weight manipulation
251
+ # ---------------------------------------------------------------------------
252
+
253
def get_weight(handle: MLXModelHandle, layer_idx: int, param_path: str) -> Any:
    """Retrieve a weight tensor from one transformer layer.

    Parameters
    ----------
    handle : MLXModelHandle
        Loaded model.
    layer_idx : int
        Transformer layer index.
    param_path : str
        Dot-separated path within the layer, e.g. ``"self_attn.o_proj.weight"``.

    Returns
    -------
    mlx.core.array
        The weight tensor.
    """
    _require_mlx()

    # Walk from the selected layer down the dotted attribute path.
    target = _get_layers(handle.model)[layer_idx]
    for attr in param_path.split("."):
        target = getattr(target, attr)
    return target
283
+
284
+
285
def modify_weights(
    handle: MLXModelHandle,
    layer_idx: int,
    param_path: str,
    modifier_fn: Callable[[Any], Any],
) -> None:
    """Replace a weight in-place via ``new = modifier_fn(old)``.

    Parameters
    ----------
    handle : MLXModelHandle
        Loaded model.
    layer_idx : int
        Transformer layer index.
    param_path : str
        Dot-separated path within the layer to the weight, e.g.
        ``"self_attn.o_proj.weight"``.
    modifier_fn : callable
        Maps the current weight (mlx array) to the replacement weight
        (mlx array). For abliteration, this projects out the refusal
        direction.
    """
    _require_mlx()
    import mlx.core as mx  # type: ignore[import-untyped]

    layer = _get_layers(handle.model)[layer_idx]

    # Split the dotted path into parent modules and the leaf attribute.
    *parent_names, leaf_name = param_path.split(".")
    parent = layer
    for name in parent_names:
        parent = getattr(parent, name)

    updated = modifier_fn(getattr(parent, leaf_name))

    # MLX modules prefer the functional `update` pattern over setattr.
    if hasattr(parent, "update"):
        parent.update({leaf_name: updated})
    else:
        setattr(parent, leaf_name, updated)

    mx.eval(updated)  # materialize the lazy computation now
331
+
332
+
333
def project_out_direction(weight: Any, direction: Any) -> Any:
    """Project a direction out of a weight matrix (abliteration).

    Given weight matrix W and unit direction d, computes::

        W' = W - (W @ d) outer d

    ``direction`` is assumed to already be unit-norm; it is not
    normalized here.

    Parameters
    ----------
    weight : mlx.core.array
        Weight matrix, shape ``(out_features, in_features)``.
    direction : mlx.core.array
        Unit direction vector, shape ``(in_features,)``.

    Returns
    -------
    mlx.core.array
        Modified weight with the direction projected out.
    """
    _require_mlx()
    import mlx.core as mx  # type: ignore[import-untyped]

    d = direction.astype(weight.dtype)
    # Component of each weight row along d, as a column vector (out, 1).
    row_components = mx.matmul(weight, d[:, None])
    # Subtract the rank-1 update (out, 1) x (1, in) -> (out, in).
    return weight - mx.matmul(row_components, d[None, :])
359
+
360
+
361
+ # ---------------------------------------------------------------------------
362
+ # Save model
363
+ # ---------------------------------------------------------------------------
364
+
365
def save_model(
    handle: MLXModelHandle,
    output_dir: str | Path,
    upload_repo: str | None = None,
) -> Path:
    """Save the (modified) MLX model to disk.

    Saves in safetensors format compatible with both MLX and HuggingFace.
    (Bug fix: removed an unused ``from mlx_lm import convert`` import.)

    Parameters
    ----------
    handle : MLXModelHandle
        Model handle (possibly with modified weights).
    output_dir : str or Path
        Directory to save into; created if missing.
    upload_repo : str or None
        If set, also uploads to HuggingFace Hub (best-effort — logs a
        warning if mlx-lm's uploader is unavailable).

    Returns
    -------
    Path
        The output directory.
    """
    _require_mlx()

    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    # Prefer mlx-lm's own saver (safetensors, HF-compatible layout).
    if hasattr(_mlx_lm, "save_model"):
        _mlx_lm.save_model(str(out), handle.model, handle.tokenizer)
    else:
        # Fallback: manual save via mlx.core.save_safetensors
        import mlx.core as mx  # type: ignore[import-untyped]

        weights = dict(handle.model.parameters())
        flat: dict = {}
        _flatten_dict(weights, "", flat)
        mx.save_safetensors(str(out / "model.safetensors"), flat)
        # Save tokenizer via transformers
        handle.tokenizer.save_pretrained(str(out))

    logger.info("MLX model saved to %s", out)

    if upload_repo:
        try:
            from mlx_lm import upload_to_hub  # type: ignore[import-untyped]

            upload_to_hub(str(out), upload_repo)
            logger.info("Uploaded to %s", upload_repo)
        except (ImportError, AttributeError):
            logger.warning("mlx-lm upload not available — push manually with huggingface-cli")

    return out
419
+
420
+
421
+ # ---------------------------------------------------------------------------
422
+ # Conversion: PyTorch ↔ MLX
423
+ # ---------------------------------------------------------------------------
424
+
425
def torch_tensor_to_mlx(tensor: "torch.Tensor") -> Any:  # noqa: F821
    """Convert a PyTorch tensor to an MLX array.

    The tensor is detached and moved to CPU. Dtype is preserved where
    NumPy supports it; ``bfloat16`` (which NumPy cannot represent) is
    upcast to ``float32``.
    """
    _require_mlx()
    import mlx.core as mx  # type: ignore[import-untyped]
    import torch

    t = tensor.detach().cpu()
    # Bug fix: the previous unconditional .float() upcast float16 tensors
    # and silently converted integer tensors to float32. Only bfloat16
    # actually needs the cast, because NumPy has no bfloat16 dtype.
    if t.dtype == torch.bfloat16:
        t = t.float()
    return mx.array(t.numpy())
434
+
435
+
436
def mlx_to_torch_tensor(array: Any, device: str = "cpu") -> "torch.Tensor":  # noqa: F821
    """Convert an MLX array to a PyTorch tensor on *device* (via NumPy)."""
    import numpy as np
    import torch

    # np.array() materializes the MLX array; from_numpy wraps it, then we
    # move the result to the requested device.
    return torch.from_numpy(np.array(array)).to(device)
443
+
444
+
445
+ # ---------------------------------------------------------------------------
446
+ # Internal helpers
447
+ # ---------------------------------------------------------------------------
448
+
449
+ def _get_layers(model: Any) -> Any:
450
+ """Locate the transformer block list in an MLX model."""
451
+ for attr_path in ("model.layers", "transformer.h", "gpt_neox.layers"):
452
+ obj = model
453
+ try:
454
+ for part in attr_path.split("."):
455
+ obj = getattr(obj, part)
456
+ return obj
457
+ except AttributeError:
458
+ continue
459
+ raise RuntimeError("Cannot locate transformer layers in MLX model")
460
+
461
+
462
+ def _flatten_dict(d: dict, prefix: str, out: dict) -> None:
463
+ """Flatten a nested dict with dot-separated keys."""
464
+ for k, v in d.items():
465
+ key = f"{prefix}{k}" if prefix else k
466
+ if isinstance(v, dict):
467
+ _flatten_dict(v, f"{key}.", out)
468
+ else:
469
+ out[key] = v
obliteratus/models/loader.py CHANGED
@@ -12,6 +12,7 @@ from typing import Optional
12
  import sys as _sys
13
 
14
  import torch
 
15
  from transformers import (
16
  AutoConfig,
17
  AutoModelForCausalLM,
@@ -381,24 +382,8 @@ def _estimate_model_memory_gb(config: AutoConfig, dtype: torch.dtype) -> float:
381
 
382
 
383
  def _available_gpu_memory_gb() -> float:
384
- """Return free GPU memory across all CUDA devices, in GB.
385
-
386
- Uses torch.cuda.mem_get_info which reports actual free memory,
387
- not total capacity. Falls back to total_memory if mem_get_info
388
- is unavailable (PyTorch < 1.10).
389
- """
390
- if not torch.cuda.is_available():
391
- return 0.0
392
- total_free = 0.0
393
- for i in range(torch.cuda.device_count()):
394
- try:
395
- free, _ = torch.cuda.mem_get_info(i)
396
- total_free += free / (1024 ** 3)
397
- except AttributeError:
398
- # Fallback for old PyTorch without mem_get_info
399
- props = torch.cuda.get_device_properties(i)
400
- total_free += props.total_memory / (1024 ** 3)
401
- return total_free
402
 
403
 
404
  def _hf_token() -> str | None:
@@ -515,34 +500,54 @@ def load_model(
515
  load_kwargs.pop("torch_dtype", None)
516
  load_kwargs["device_map"] = "auto"
517
  elif quantization in ("4bit", "8bit"):
518
- try:
519
- import bitsandbytes # noqa: F401
520
- except ImportError:
521
- raise RuntimeError(
522
- f"Quantization '{quantization}' requires bitsandbytes: "
523
- f"pip install -U bitsandbytes>=0.46.1"
524
- )
525
- from transformers import BitsAndBytesConfig
526
-
527
- # Enable fp32 CPU offload so that models too large to fit entirely on
528
- # GPU (even quantized) can spill to CPU without crashing bitsandbytes.
529
- # This is critical for frontier MoE models (GLM-5 744B, DeepSeek-V3 685B,
530
- # Mistral Large 3 675B, etc.) on single-GPU setups.
531
- if quantization == "4bit":
532
- load_kwargs["quantization_config"] = BitsAndBytesConfig(
533
- load_in_4bit=True,
534
- bnb_4bit_compute_dtype=torch_dtype,
535
- bnb_4bit_quant_type="nf4",
536
- llm_int8_enable_fp32_cpu_offload=True,
537
  )
 
 
 
 
538
  else:
539
- load_kwargs["quantization_config"] = BitsAndBytesConfig(
540
- load_in_8bit=True,
541
- llm_int8_enable_fp32_cpu_offload=True,
542
- )
543
- load_kwargs["device_map"] = "auto"
544
- elif device == "auto":
545
- load_kwargs["device_map"] = "auto"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
546
 
547
  # Offload support: provide a folder for disk offloading when GPU memory is insufficient
548
  _offload_dir = None
@@ -560,9 +565,9 @@ def load_model(
560
  # Reserve GPU headroom for inference (KV cache, activations, generate()).
561
  # Without this, device_map="auto" packs 100% of layers onto GPU, leaving
562
  # no room for forward passes or generation on tight-memory setups.
563
- if torch.cuda.is_available():
564
  max_memory = {}
565
- for i in range(torch.cuda.device_count()):
566
  total = torch.cuda.get_device_properties(i).total_memory
567
  # Reserve 15% or 2 GiB (whichever is larger) for inference headroom
568
  reserve = max(int(total * 0.15), 2 * 1024 ** 3)
@@ -570,16 +575,8 @@ def load_model(
570
  max_memory[i] = f"{usable // (1024 ** 2)}MiB"
571
  # Allow overflow to CPU RAM, capped at 85% of physical memory
572
  # to leave room for the OS, Python runtime, and serialization buffers.
573
- try:
574
- import psutil
575
- cpu_ram_gb = psutil.virtual_memory().total / (1024 ** 3)
576
- except ImportError:
577
- try:
578
- cpu_ram_gb = os.sysconf("SC_PHYS_PAGES") * os.sysconf("SC_PAGE_SIZE") / (1024 ** 3)
579
- except (AttributeError, ValueError):
580
- # os.sysconf is unavailable on non-POSIX platforms (Windows)
581
- cpu_ram_gb = 16.0 # conservative fallback
582
- cpu_budget_gb = int(cpu_ram_gb * 0.85)
583
  max_memory["cpu"] = f"{max(cpu_budget_gb, 4)}GiB"
584
  load_kwargs["max_memory"] = max_memory
585
  logger.info(
@@ -625,12 +622,15 @@ def load_model(
625
 
626
  if device not in ("auto",) and quantization is None and native_quant is None:
627
  model = model.to(device)
 
 
 
 
628
 
629
  model.eval()
630
 
631
- # Free CUDA cache after loading
632
- if torch.cuda.is_available():
633
- torch.cuda.empty_cache()
634
 
635
  try:
636
  tokenizer = AutoTokenizer.from_pretrained(
@@ -665,9 +665,7 @@ def load_model(
665
  if gpu_gb > 0 and native_quant is not None:
666
  # Model is pre-quantized but we can't estimate its true size.
667
  # Check actual free memory after loading — if less than 40% free, skip snapshot.
668
- free_gb = 0.0
669
- for i in range(torch.cuda.device_count()):
670
- free_gb += torch.cuda.mem_get_info(i)[0] / (1024 ** 3)
671
  if free_gb < gpu_gb * 0.4:
672
  logger.warning(
673
  f"Auto-skipping state dict snapshot for natively quantized model "
 
12
  import sys as _sys
13
 
14
  import torch
15
+ from obliteratus import device as dev
16
  from transformers import (
17
  AutoConfig,
18
  AutoModelForCausalLM,
 
382
 
383
 
384
  def _available_gpu_memory_gb() -> float:
385
+ """Return free accelerator memory in GB (CUDA, MPS, or 0 for CPU)."""
386
+ return dev.get_total_free_gb()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
387
 
388
 
389
  def _hf_token() -> str | None:
 
500
  load_kwargs.pop("torch_dtype", None)
501
  load_kwargs["device_map"] = "auto"
502
  elif quantization in ("4bit", "8bit"):
503
+ # BitsAndBytes only works on NVIDIA CUDA GPUs.
504
+ resolved_device = dev.get_device(device)
505
+ if not dev.supports_bitsandbytes(resolved_device):
506
+ logger.warning(
507
+ "BitsAndBytes quantization is not supported on %s. "
508
+ "Loading in %s instead.",
509
+ resolved_device, dtype,
 
 
 
 
 
 
 
 
 
 
 
 
510
  )
511
+ # On MPS, load normally to the device; on CPU, fall through.
512
+ if resolved_device == "mps":
513
+ device = "mps"
514
+ # Don't set quantization_config — fall through to normal loading.
515
  else:
516
+ try:
517
+ import bitsandbytes # noqa: F401
518
+ except ImportError:
519
+ raise RuntimeError(
520
+ f"Quantization '{quantization}' requires bitsandbytes: "
521
+ f"pip install -U bitsandbytes>=0.46.1"
522
+ )
523
+ from transformers import BitsAndBytesConfig
524
+
525
+ # Enable fp32 CPU offload so that models too large to fit entirely on
526
+ # GPU (even quantized) can spill to CPU without crashing bitsandbytes.
527
+ # This is critical for frontier MoE models (GLM-5 744B, DeepSeek-V3 685B,
528
+ # Mistral Large 3 675B, etc.) on single-GPU setups.
529
+ if quantization == "4bit":
530
+ load_kwargs["quantization_config"] = BitsAndBytesConfig(
531
+ load_in_4bit=True,
532
+ bnb_4bit_compute_dtype=torch_dtype,
533
+ bnb_4bit_quant_type="nf4",
534
+ llm_int8_enable_fp32_cpu_offload=True,
535
+ )
536
+ else:
537
+ load_kwargs["quantization_config"] = BitsAndBytesConfig(
538
+ load_in_8bit=True,
539
+ llm_int8_enable_fp32_cpu_offload=True,
540
+ )
541
+ load_kwargs["device_map"] = "auto"
542
+
543
+ # device_map="auto" is only reliable on CUDA (accelerate doesn't support MPS).
544
+ if "device_map" not in load_kwargs and device == "auto":
545
+ resolved_device = dev.get_device(device)
546
+ if dev.supports_device_map_auto(resolved_device):
547
+ load_kwargs["device_map"] = "auto"
548
+ else:
549
+ # MPS / CPU: load to CPU first, then .to(device) after loading.
550
+ pass
551
 
552
  # Offload support: provide a folder for disk offloading when GPU memory is insufficient
553
  _offload_dir = None
 
565
  # Reserve GPU headroom for inference (KV cache, activations, generate()).
566
  # Without this, device_map="auto" packs 100% of layers onto GPU, leaving
567
  # no room for forward passes or generation on tight-memory setups.
568
+ if dev.is_cuda():
569
  max_memory = {}
570
+ for i in range(dev.device_count()):
571
  total = torch.cuda.get_device_properties(i).total_memory
572
  # Reserve 15% or 2 GiB (whichever is larger) for inference headroom
573
  reserve = max(int(total * 0.15), 2 * 1024 ** 3)
 
575
  max_memory[i] = f"{usable // (1024 ** 2)}MiB"
576
  # Allow overflow to CPU RAM, capped at 85% of physical memory
577
  # to leave room for the OS, Python runtime, and serialization buffers.
578
+ total_ram, _ = dev._system_memory_gb()
579
+ cpu_budget_gb = int(total_ram * 0.85)
 
 
 
 
 
 
 
 
580
  max_memory["cpu"] = f"{max(cpu_budget_gb, 4)}GiB"
581
  load_kwargs["max_memory"] = max_memory
582
  logger.info(
 
622
 
623
  if device not in ("auto",) and quantization is None and native_quant is None:
624
  model = model.to(device)
625
+ elif device == "auto" and not dev.supports_device_map_auto():
626
+ # MPS / CPU: device_map wasn't used, move model to best device.
627
+ resolved = dev.get_device()
628
+ model = model.to(resolved)
629
 
630
  model.eval()
631
 
632
+ # Free accelerator cache after loading
633
+ dev.empty_cache()
 
634
 
635
  try:
636
  tokenizer = AutoTokenizer.from_pretrained(
 
665
  if gpu_gb > 0 and native_quant is not None:
666
  # Model is pre-quantized but we can't estimate its true size.
667
  # Check actual free memory after loading — if less than 40% free, skip snapshot.
668
+ free_gb = dev.get_total_free_gb()
 
 
669
  if free_gb < gpu_gb * 0.4:
670
  logger.warning(
671
  f"Auto-skipping state dict snapshot for natively quantized model "
obliteratus/reproducibility.py CHANGED
@@ -38,9 +38,9 @@ def set_seed(seed: int = 42, deterministic: bool = True) -> None:
38
 
39
  try:
40
  import torch
 
41
  torch.manual_seed(seed)
42
- if torch.cuda.is_available():
43
- torch.cuda.manual_seed_all(seed)
44
 
45
  if deterministic:
46
  torch.use_deterministic_algorithms(True, warn_only=True)
 
38
 
39
  try:
40
  import torch
41
+ from obliteratus import device as dev
42
  torch.manual_seed(seed)
43
+ dev.set_seed_all(seed)
 
44
 
45
  if deterministic:
46
  torch.use_deterministic_algorithms(True, warn_only=True)
obliteratus/tourney.py CHANGED
@@ -372,9 +372,8 @@ class TourneyRunner:
372
  # Clean up GPU between methods
373
  gc.collect()
374
  try:
375
- import torch
376
- if torch.cuda.is_available():
377
- torch.cuda.empty_cache()
378
  except Exception:
379
  pass
380
 
 
372
  # Clean up GPU between methods
373
  gc.collect()
374
  try:
375
+ from obliteratus import device as dev
376
+ dev.empty_cache()
 
377
  except Exception:
378
  pass
379
 
requirements-apple.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Optional Apple Silicon dependencies for native MLX acceleration.
2
+ # Install alongside the main requirements on macOS with Apple Silicon:
3
+ #
4
+ # pip install -r requirements.txt -r requirements-apple.txt
5
+ #
6
+ # These packages are macOS-only and will fail to install on Linux/Windows.
7
+ mlx>=0.22
8
+ mlx-lm>=0.20