File size: 23,417 Bytes
ed1b365
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
#!/usr/bin/env python3
"""Pipeline 2: CPU-Offload LoRA Training for Codette Adapters

Ultra-low-memory training using disk offloading and aggressive memory management.
Designed for machines where the model doesn't fit in physical RAM.
Relies heavily on the Windows page file — the OS swaps model layers to/from disk.

Memory:  ~8-12 GB active RAM (rest swapped to page file)
Speed:   ~2-5 min per step (heavy disk I/O from page swapping)
Time:    ~6-24 hours per adapter (slow but reliable)

Key differences from Pipeline 1 (lean):
  - LoRA rank=4 (half the parameters)
  - Shorter sequences (128 tokens vs 256)
  - SGD optimizer (50% less memory than AdamW)
  - Aggressive garbage collection every step
  - Layer-by-layer model loading (lower peak RAM)
  - Memory monitoring with automatic abort if critical

Usage:
    python train_cpu_offload.py newton
    python train_cpu_offload.py empathy --epochs 2
    python train_cpu_offload.py --pagefile-info        # Show page file guidance
    python train_cpu_offload.py --list                 # Show available adapters
    python train_cpu_offload.py newton --resume        # Resume from checkpoint

IMPORTANT: Ensure your page file is at least 24 GB.
Run with --pagefile-info for setup instructions.
"""

import os, sys, time, json, gc, argparse, math
from pathlib import Path
from datetime import datetime, timedelta

# ── Environment bootstrap ───────────────────────────────────────
# Put the J: drive site-packages directory at the front of sys.path so the
# third-party packages imported later (inside train_adapter_offload) are
# resolved from there first.
_site = r"J:\Lib\site-packages"
if _site not in sys.path:
    sys.path.insert(0, _site)
# Prepend the matching Library\bin directory to PATH.
# NOTE(review): presumably so native DLLs shipped with those packages can be
# located at import time — confirm.
os.environ["PATH"] = (
    r"J:\Lib\site-packages\Library\bin" + os.pathsep + os.environ.get("PATH", "")
)
# Point the Hugging Face cache at J:. These assignments unconditionally
# overwrite any value already present in the environment.
os.environ["HF_HOME"] = r"J:\hf_cache"
os.environ["TRANSFORMERS_CACHE"] = r"J:\hf_cache"

# Allocator-related environment switches intended to reduce memory overhead.
# NOTE(review): the first name refers to CUDA caching and the second looks
# like a glibc malloc tunable; this pipeline loads the model on CPU under
# Windows, so verify that either has any effect here.
os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"
os.environ["MALLOC_TRIM_THRESHOLD_"] = "0"

# Best effort: switch stdout to UTF-8 and replace unencodable characters
# instead of raising. Streams without reconfigure() are left unchanged.
try:
    sys.stdout.reconfigure(encoding='utf-8', errors='replace')
except Exception:
    pass


# ── Set IDLE priority ──────────────────────────────────────────
def set_idle_priority():
    """Drop this process to the Windows IDLE priority class (best effort).

    At IDLE priority the process only runs when nothing else needs the CPU.
    Every failure is swallowed on purpose (for example when
    ``ctypes.windll`` is not available), in which case the priority is
    left unchanged and nothing is printed.
    """
    idle_priority_class = 0x00000040
    try:
        import ctypes

        kernel32 = ctypes.windll.kernel32
        current_process = kernel32.GetCurrentProcess()
        kernel32.SetPriorityClass(current_process, idle_priority_class)
        print("  Process priority: IDLE (only uses spare CPU cycles)")
    except Exception:
        # Priority tuning is optional; never let it stop a training run.
        pass


# ── Memory monitoring ──────────────────────────────────────────
def get_memory_info():
    """Query system memory through the Win32 ``GlobalMemoryStatusEx`` call.

    Returns:
        A dict with the keys ``ram_used``, ``ram_total``, ``ram_avail``,
        ``page_used``, ``page_total`` and ``page_avail`` (byte counts divided
        by 1e9, i.e. decimal GB) plus ``pct`` (the ``dwMemoryLoad`` field).
        If anything goes wrong while querying, every value is ``0``.
    """
    stat_keys = (
        "ram_used", "ram_total", "ram_avail",
        "page_used", "page_total", "page_avail", "pct",
    )
    try:
        import ctypes

        dword, qword = ctypes.c_ulong, ctypes.c_ulonglong
        layout = [("dwLength", dword), ("dwMemoryLoad", dword)]
        layout.extend(
            (field, qword)
            for field in (
                "ullTotalPhys", "ullAvailPhys",
                "ullTotalPageFile", "ullAvailPageFile",
                "ullTotalVirtual", "ullAvailVirtual",
                "ullAvailExtendedVirtual",
            )
        )

        class MEMSTAT(ctypes.Structure):
            _fields_ = layout

        # The API requires dwLength to hold the structure size before the call.
        status = MEMSTAT(dwLength=ctypes.sizeof(MEMSTAT))
        ctypes.windll.kernel32.GlobalMemoryStatusEx(ctypes.byref(status))

        bytes_per_gb = 1e9
        phys_total, phys_free = status.ullTotalPhys, status.ullAvailPhys
        page_total, page_free = status.ullTotalPageFile, status.ullAvailPageFile
        values = (
            (phys_total - phys_free) / bytes_per_gb,
            phys_total / bytes_per_gb,
            phys_free / bytes_per_gb,
            (page_total - page_free) / bytes_per_gb,
            page_total / bytes_per_gb,
            page_free / bytes_per_gb,
            status.dwMemoryLoad,
        )
        return dict(zip(stat_keys, values))
    except Exception:
        # Stats unavailable: report zeros rather than failing the caller.
        return dict.fromkeys(stat_keys, 0)


def check_memory_safe(label=""):
    """Print a one-line memory summary and warn when the page file is nearly full.

    This function only reports; it never aborts the process.

    Args:
        label: Short tag printed in brackets to identify the call site.

    Returns:
        The dict produced by ``get_memory_info()``.
    """
    info = get_memory_info()
    print(
        f"  [{label}] RAM: {info['ram_used']:.1f}/{info['ram_total']:.1f} GB "
        f"({info['pct']}%) | Page avail: {info['page_avail']:.1f} GB"
    )
    # get_memory_info() reports all zeros when the query fails. Without the
    # page_total guard that case would be mistaken for a full page file and
    # produce a misleading warning on every call.
    stats_available = info["page_total"] > 0
    if stats_available and info["page_avail"] < 2.0:
        print(f"\n  WARNING: Page file nearly full! ({info['page_avail']:.1f} GB left)")
        print("  Training may crash. Increase page file size or close other programs.")
        print("  Run: python train_cpu_offload.py --pagefile-info")
    return info


def aggressive_cleanup():
    """Run the garbage collector twice, then ask Windows to trim the working set.

    The working-set trim is best effort: any failure (for example when
    ``ctypes.windll`` is not available) is ignored.
    """
    for _ in range(2):
        gc.collect()

    try:
        import ctypes

        k32 = ctypes.windll.kernel32
        # Passing -1 for both the minimum and maximum size asks the OS to
        # trim the process working set.
        no_limit = ctypes.c_size_t(-1)
        k32.SetProcessWorkingSetSize(k32.GetCurrentProcess(), no_limit, ctypes.c_size_t(-1))
    except Exception:
        pass


# ── Configuration ──────────────────────────────────────────────
# Fixed on-disk layout of the training lab.
PROJECT_ROOT = Path(r"J:\codette-training-lab")
DATASET_DIR = PROJECT_ROOT.joinpath("datasets")
ADAPTER_OUT = PROJECT_ROOT.joinpath("adapters")
CKPT_DIR = PROJECT_ROOT.joinpath("training", "checkpoints_offload")
GGUF_CONVERTER = Path(r"J:\TheAI\llama.cpp\convert_lora_to_gguf.py")

# Base model every adapter is trained against.
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"

# Adapter name -> dataset file name (relative to DATASET_DIR) and the example
# count shown by --list. Every dataset file is named "<adapter>_reasoning.jsonl".
ADAPTER_CONFIG = {
    name: {"dataset": f"{name}_reasoning.jsonl", "examples": count}
    for name, count in (
        ("newton", 3000),
        ("davinci", 2500),
        ("empathy", 2500),
        ("philosophy", 2000),
        ("quantum", 2000),
        ("consciousness", 3000),
        ("multi_perspective", 2500),
        ("systems_architecture", 2000),
    )
}


# ── Dataset loading ────────────────────────────────────────────
def load_dataset_jsonl(adapter_name, max_examples=None):
    """Load the chat-format JSONL dataset configured for an adapter.

    Blank lines are skipped. Reading stops as soon as ``max_examples`` records
    have been collected, so a capped run does not parse the rest of the file.

    Args:
        adapter_name: Key into ``ADAPTER_CONFIG``.
        max_examples: Maximum number of records to return. ``None``, ``0`` or
            a negative value means "no limit".

    Returns:
        A list with one parsed JSON object per non-empty line.

    Raises:
        KeyError: If ``adapter_name`` is not in ``ADAPTER_CONFIG``.
        FileNotFoundError: If the dataset file does not exist.
        json.JSONDecodeError: If a line that is read is not valid JSON.
    """
    cfg = ADAPTER_CONFIG[adapter_name]
    path = DATASET_DIR / cfg["dataset"]
    if not path.exists():
        raise FileNotFoundError(f"Dataset not found: {path}")

    limit = max_examples if max_examples and max_examples > 0 else None

    data = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            data.append(json.loads(line))
            if limit is not None and len(data) >= limit:
                break

    print(f"  Dataset: {path.name} ({len(data)} examples)")
    return data


def format_chat_to_text(messages, tokenizer):
    """Render a list of chat messages as a single training string.

    The tokenizer's own chat template is tried first. If that call raises for
    any reason, the text is assembled by hand from each message's ``role``
    and ``content`` using the header/end-of-turn markers below.

    Args:
        messages: Sequence of dicts with ``"role"`` and ``"content"`` keys.
        tokenizer: Object expected to provide ``apply_chat_template``.

    Returns:
        The formatted conversation text.
    """
    try:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    except Exception:
        # Manual fallback: one header + body + end-of-turn marker per message.
        turns = "".join(
            f"<|start_header_id|>{msg['role']}<|end_header_id|>\n\n{msg['content']}<|eot_id|>"
            for msg in messages
        )
        return "<|begin_of_text|>" + turns


# ── Training ───────────────────────────────────────────────────
def _latest_checkpoint(ckpt_dir):
    """Return ``(path, step)`` for the highest-numbered checkpoint in ckpt_dir, or None.

    Both regular (``step_N``) and emergency (``emergency_step_N``) checkpoints
    are considered, and they are compared by their numeric step so that
    ``step_100`` ranks above ``step_50``.
    """
    best = None
    for pattern in ("step_*", "emergency_step_*"):
        for path in ckpt_dir.glob(pattern):
            try:
                step = int(path.name.rsplit("_", 1)[1])
            except ValueError:
                continue  # not a checkpoint directory we wrote
            if best is None or step > best[1]:
                best = (path, step)
    return best


def _tokenize_examples(raw_data, tokenizer, max_seq_len):
    """Tokenize chat examples into fixed-length tensors ready for the training loop."""
    tokenized = []
    for item in raw_data:
        text = format_chat_to_text(item["messages"], tokenizer)
        tokens = tokenizer(
            text,
            truncation=True,
            max_length=max_seq_len,
            padding="max_length",
            return_tensors="pt",
            # The formatted text already starts with <|begin_of_text|>;
            # letting the tokenizer add special tokens would duplicate it.
            add_special_tokens=False,
        )
        input_ids = tokens["input_ids"].squeeze(0)
        attention_mask = tokens["attention_mask"].squeeze(0)

        # Drop examples with fewer than 10 real (non-padding) tokens.
        if attention_mask.sum().item() < 10:
            continue

        # Padding reuses the EOS token id, so it must be excluded from the
        # loss explicitly: -100 is the index the loss function ignores.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        tokenized.append({
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        })
    return tokenized


def train_adapter_offload(
    adapter_name,
    epochs=2,
    rank=4,
    alpha=8,
    lr=1e-4,
    batch_size=1,
    grad_accum=8,
    max_seq_len=128,
    save_steps=50,
    resume=False,
    max_examples=None,
):
    """Train a LoRA adapter with extreme memory optimization.

    The dataset is tokenized before the model is loaded, the model is loaded
    on CPU in bf16 with gradient checkpointing, LoRA is attached to ``q_proj``
    only, and training uses SGD with momentum plus gradient accumulation.
    Examples left over at the end of an epoch that do not fill a complete
    accumulation window are discarded.

    Args:
        adapter_name: Key into ``ADAPTER_CONFIG``.
        epochs: Number of passes over the dataset.
        rank: LoRA rank.
        alpha: LoRA alpha.
        lr: SGD learning rate.
        batch_size: Only used for the "effective batch" banner; examples are
            always fed one at a time.
        grad_accum: Micro-batches accumulated per optimizer step. Reduced
            automatically if the dataset has fewer usable examples.
        max_seq_len: Every example is truncated/padded to this many tokens.
        save_steps: Save a checkpoint every this many optimizer steps.
        resume: Continue from the highest-numbered checkpoint, if any.
        max_examples: Optional cap on the number of dataset records loaded.

    Returns:
        The final adapter directory as a string, or the emergency checkpoint
        path if training stopped because the page file was almost exhausted.

    Raises:
        ValueError: If ``grad_accum`` or ``save_steps`` is below 1, or if no
            usable training example remains after tokenization.
    """
    import random

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import LoraConfig, get_peft_model, TaskType

    # Fail before the slow model load rather than with a ZeroDivisionError
    # hours into the run.
    if grad_accum < 1:
        raise ValueError(f"grad_accum must be >= 1, got {grad_accum}")
    if save_steps < 1:
        raise ValueError(f"save_steps must be >= 1, got {save_steps}")

    set_idle_priority()
    check_memory_safe("startup")

    # ── Check page file is adequate ─────────────────────────
    # page_total == 0 means the stats could not be read; stay silent then.
    info = get_memory_info()
    if 0 < info["page_total"] < 20:
        print(f"\n  WARNING: Page file is only {info['page_total']:.1f} GB.")
        print("  Recommend at least 24 GB for offload training.")
        print("  Run: python train_cpu_offload.py --pagefile-info")
        print("  Continuing anyway...\n")

    # ── Load tokenizer ──────────────────────────────────────
    print("\n  Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    aggressive_cleanup()

    # ── Pre-tokenize dataset BEFORE loading model ───────────
    # This way we can free the raw data before the model needs RAM
    print("  Pre-tokenizing dataset (before model load to save RAM)...")
    raw_data = load_dataset_jsonl(adapter_name, max_examples=max_examples)
    tokenized = _tokenize_examples(raw_data, tokenizer, max_seq_len)

    del raw_data
    aggressive_cleanup()
    print(f"  Tokenized: {len(tokenized)} examples (max_seq_len={max_seq_len})")
    check_memory_safe("after tokenize")

    if not tokenized:
        raise ValueError(f"No usable training examples for adapter '{adapter_name}'")
    if len(tokenized) < grad_accum:
        # Otherwise no accumulation window would ever fill and no optimizer
        # step would be taken.
        print(
            f"  NOTE: only {len(tokenized)} examples; "
            f"reducing grad_accum from {grad_accum} to {len(tokenized)}"
        )
        grad_accum = len(tokenized)
    steps_per_epoch = len(tokenized) // grad_accum

    # ── Load model with extreme low-memory settings ─────────
    print("\n  Loading model in bf16 with low_cpu_mem_usage...")
    print("  This will use page file heavily — expect disk activity.")
    print(f"  First run downloads ~16 GB to {os.environ['HF_HOME']}")

    load_start = time.time()
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        device_map="cpu",
    )
    model.config.use_cache = False
    print(f"  Model loaded in {time.time() - load_start:.0f}s")

    # Enable gradient checkpointing (critical for memory)
    model.gradient_checkpointing_enable()
    aggressive_cleanup()
    check_memory_safe("after model load")

    # ── Configure minimal LoRA ──────────────────────────────
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=rank,
        lora_alpha=alpha,
        lora_dropout=0.0,           # No dropout saves a tiny bit of memory
        target_modules=["q_proj"],   # Single target = minimum LoRA parameters
        bias="none",
    )

    model = get_peft_model(model, lora_config)
    trainable, total = model.get_nb_trainable_parameters()
    print(f"  LoRA: rank={rank}, alpha={alpha}, target=q_proj ONLY")
    print(f"  Trainable: {trainable:,} / {total:,} ({100*trainable/total:.4f}%)")

    # ── Checkpoint handling ─────────────────────────────────
    ckpt_path = CKPT_DIR / adapter_name
    ckpt_path.mkdir(parents=True, exist_ok=True)
    start_step = 0
    start_epoch = 0
    skip_examples = 0  # examples of the resumed epoch that were already consumed

    if resume:
        found = _latest_checkpoint(ckpt_path)
        if found:
            latest, start_step = found
            print(f"  Resuming from: {latest.name}")
            # is_trainable=True states explicitly that the reloaded adapter
            # is going to be trained further.
            model.load_adapter(str(latest), adapter_name="default", is_trainable=True)
            start_epoch, steps_done_in_epoch = divmod(start_step, steps_per_epoch)
            skip_examples = steps_done_in_epoch * grad_accum

    aggressive_cleanup()
    check_memory_safe("ready to train")

    # ── SGD optimizer (much less memory than AdamW) ─────────
    # AdamW stores 2 extra buffers per parameter (momentum + variance)
    # SGD with momentum stores only 1 extra buffer
    # NOTE: optimizer state is not checkpointed, so momentum restarts from
    # zero after --resume.
    optimizer = torch.optim.SGD(
        [p for p in model.parameters() if p.requires_grad],
        lr=lr,
        momentum=0.9,
        weight_decay=0.01,
    )

    # ── Training loop ───────────────────────────────────────
    total_steps = steps_per_epoch * epochs
    print(f"\n{'='*60}")
    print(f"  OFFLOAD TRAINING: {adapter_name}")
    print(f"  Epochs: {epochs} | Steps: {total_steps}")
    print(f"  Effective batch: {batch_size * grad_accum}")
    print(f"  Seq len: {max_seq_len} | LR: {lr} | Optimizer: SGD+momentum")
    print(f"  Rank: {rank} | Target: q_proj only")
    est_time = max(total_steps - start_step, 0) * 180  # ~3 min/step with page file swapping
    print(f"  Est. time: {timedelta(seconds=est_time)} (with page file I/O)")
    print(f"{'='*60}\n")

    model.train()
    optimizer.zero_grad(set_to_none=True)
    global_step = start_step
    running_loss = 0.0
    step_times = []

    for epoch in range(start_epoch, epochs):
        print(f"  --- Epoch {epoch+1}/{epochs} ---")
        random.shuffle(tokenized)

        # Only the first resumed epoch is shortened; the shuffle is unseeded,
        # so this skips the right *number* of examples, not the same ones.
        epoch_examples = tokenized[skip_examples:] if skip_examples else tokenized
        skip_examples = 0

        accum_loss = 0.0
        accum_count = 0
        # Timed per optimizer step, i.e. across the whole accumulation window.
        step_start = time.time()

        for batch in epoch_examples:
            input_ids = batch["input_ids"].unsqueeze(0)
            attention_mask = batch["attention_mask"].unsqueeze(0)
            labels = batch["labels"].unsqueeze(0)

            # Forward + backward
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            loss = outputs.loss / grad_accum
            loss.backward()

            accum_loss += outputs.loss.item()
            accum_count += 1

            # Immediately free forward pass memory
            del outputs, loss
            aggressive_cleanup()

            if accum_count < grad_accum:
                continue

            # Gradient accumulation step
            torch.nn.utils.clip_grad_norm_(
                [p for p in model.parameters() if p.requires_grad],
                max_norm=1.0,
            )
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)  # set_to_none saves memory
            global_step += 1

            avg_loss = accum_loss / accum_count
            running_loss = (0.9 * running_loss + 0.1 * avg_loss) if running_loss > 0 else avg_loss
            now = time.time()
            step_times.append(now - step_start)
            step_start = now

            # Log the first 5 steps, then every second step.
            if global_step % 2 == 0 or global_step <= 5:
                recent = step_times[-10:]
                avg_step = sum(recent) / len(recent)
                remaining = max(total_steps - global_step, 0) * avg_step
                info = get_memory_info()

                print(
                    f"  step {global_step:>4}/{total_steps} | "
                    f"loss={avg_loss:.4f} | "
                    f"{avg_step:.0f}s/step | "
                    f"RAM={info['ram_used']:.1f}GB page={info['page_used']:.1f}GB | "
                    f"ETA={timedelta(seconds=int(remaining))}"
                )

            # Save checkpoint
            if global_step % save_steps == 0:
                save_path = ckpt_path / f"step_{global_step}"
                model.save_pretrained(str(save_path))
                print(f"  >> Saved: {save_path.name}")
                aggressive_cleanup()

            # Check memory safety
            if global_step % 20 == 0:
                info = get_memory_info()
                # page_total == 0 means the stats are unavailable (all-zero
                # fallback); that must not be mistaken for a full page file.
                if info["page_total"] > 0 and info["page_avail"] < 1.0:
                    print(f"\n  CRITICAL: Only {info['page_avail']:.1f} GB page file left!")
                    print("  Saving emergency checkpoint and stopping...")
                    emerg_path = ckpt_path / f"emergency_step_{global_step}"
                    model.save_pretrained(str(emerg_path))
                    print(f"  Saved: {emerg_path}")
                    print("  Increase page file and run with --resume")
                    return str(emerg_path)

            accum_loss = 0.0
            accum_count = 0
            aggressive_cleanup()

        # Discard gradients of a trailing partial window so they do not leak
        # into the first optimizer step of the next epoch.
        optimizer.zero_grad(set_to_none=True)

        print(f"  Epoch {epoch+1} done | Loss: {running_loss:.4f}")
        aggressive_cleanup()

    # ── Save final ──────────────────────────────────────────
    print(f"\n{'='*60}")
    print(f"  TRAINING COMPLETE: {adapter_name}")
    print(f"{'='*60}")

    final_path = ADAPTER_OUT / f"{adapter_name}-lora-offload"
    model.save_pretrained(str(final_path))
    tokenizer.save_pretrained(str(final_path))
    print(f"  Saved: {final_path}")
    print(f"  Final loss: {running_loss:.4f}")

    if step_times:
        total_time = sum(step_times)
        print(f"  Total time: {timedelta(seconds=int(total_time))}")

    # Convert to GGUF
    convert_to_gguf(adapter_name, final_path)
    return str(final_path)


def convert_to_gguf(adapter_name, adapter_path):
    """Run the llama.cpp LoRA converter on a saved adapter (best effort).

    Nothing is raised to the caller: a missing converter script, a non-zero
    exit code, a timeout (10 minutes) or any other error is only reported on
    stdout.

    Args:
        adapter_name: Used to build the output name ``<name>-lora-f16.gguf``
            inside ``ADAPTER_OUT``.
        adapter_path: Directory holding the saved adapter.
    """
    if not GGUF_CONVERTER.exists():
        print("  GGUF converter not found. Convert manually later.")
        return

    target = ADAPTER_OUT / f"{adapter_name}-lora-f16.gguf"
    print("\n  Converting to GGUF...")

    import subprocess
    try:
        # NOTE(review): MODEL_ID is a hub id; confirm the converter version in
        # use accepts that for --base rather than a local model directory.
        command = [
            sys.executable, str(GGUF_CONVERTER),
            "--base", MODEL_ID,
            str(adapter_path),
            "--outfile", str(target),
        ]
        proc = subprocess.run(command, capture_output=True, text=True, timeout=600)
        if proc.returncode != 0:
            print(f"  GGUF conversion failed: {proc.stderr[:300]}")
        else:
            size_mb = target.stat().st_size / 1e6
            print(f"  GGUF ready: {target} ({size_mb:.1f} MB)")
    except Exception as e:
        print(f"  GGUF error: {e}")


def show_pagefile_info():
    """Print the current memory/page-file figures and a Windows page-file setup guide.

    The figures come from ``get_memory_info()``; they read 0.0 when the
    memory query is unavailable.
    """
    info = get_memory_info()

    print(f"""
{'='*60}
  PAGE FILE CONFIGURATION GUIDE
{'='*60}

  Current system:
    Physical RAM:   {info['ram_total']:.1f} GB
    Page file:      {info['page_total']:.1f} GB (current)
    Page available: {info['page_avail']:.1f} GB

  Recommended page file for Codette training:
    Pipeline 1 (lean):    24 GB minimum, 32 GB recommended
    Pipeline 2 (offload): 32 GB minimum, 48 GB recommended

  How to adjust page file on Windows:
  ──────────────────────────────────
  1. Open: Settings > System > About > Advanced system settings
     (or run: SystemPropertiesAdvanced.exe)

  2. Click "Settings..." under Performance

  3. Go to "Advanced" tab > "Change..." under Virtual Memory

  4. Uncheck "Automatically manage paging file size"

  5. Select C: drive (internal NVMe SSD — fastest option)

  6. Choose "Custom size":
     Initial size (MB): 32768    (32 GB)
     Maximum size (MB):  65536   (64 GB)

  7. Click "Set" then "OK"

  8. Restart required for changes to take effect

  NOTE: Page files must be on internal (non-USB) drives.
  C: is the NVMe SSD — best performance for page file swapping.

  After adjusting, verify with:
    python train_cpu_offload.py --pagefile-info
{'='*60}
""")


# ── CLI ────────────────────────────────────────────────────────
def _positive_int(value):
    """argparse ``type=`` callable: parse *value* as an integer that is at least 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {number}")
    return number


def main():
    """Command-line entry point.

    Parses the arguments, then either prints the page-file guide
    (``--pagefile-info``), lists the configured adapters (``--list`` or no
    adapter given), or trains the requested adapter.

    Counts that are later used as divisors or loop bounds (``--epochs``,
    ``--rank``, ``--seq-len``, ``--grad-accum``, ``--save-steps``) are
    validated at parse time, so a bad value is rejected immediately instead
    of crashing after the model has been loaded.

    Exits with status 1 for an unknown adapter or a failed training run, and
    with argparse's status 2 for invalid arguments. A keyboard interrupt
    prints a resume hint and returns normally.
    """
    parser = argparse.ArgumentParser(
        description="CPU-Offload LoRA Trainer for Codette (ultra-low memory)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
This pipeline is designed for training with limited physical RAM.
It uses the Windows page file to swap model layers to disk as needed.
Training is slow but reliable — perfect for overnight runs.

Examples:
  python train_cpu_offload.py newton
  python train_cpu_offload.py empathy --epochs 2
  python train_cpu_offload.py --pagefile-info
  python train_cpu_offload.py --list
        """,
    )
    parser.add_argument("adapter", nargs="?", help="Adapter to train")
    parser.add_argument("--list", action="store_true", help="List adapters")
    parser.add_argument("--pagefile-info", action="store_true", help="Page file setup guide")
    parser.add_argument("--epochs", type=_positive_int, default=2, help="Epochs (default: 2)")
    parser.add_argument("--rank", type=_positive_int, default=4, help="LoRA rank (default: 4)")
    parser.add_argument("--alpha", type=int, default=8, help="LoRA alpha (default: 8)")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate (default: 1e-4)")
    parser.add_argument("--seq-len", type=_positive_int, default=128, help="Max seq length (default: 128)")
    parser.add_argument("--grad-accum", type=_positive_int, default=8, help="Grad accum (default: 8)")
    parser.add_argument("--save-steps", type=_positive_int, default=50, help="Checkpoint every N steps")
    parser.add_argument("--resume", action="store_true", help="Resume from checkpoint")
    parser.add_argument("--max-examples", type=int, default=None, help="Limit dataset")
    args = parser.parse_args()

    print("=" * 60)
    print("  CODETTE CPU-OFFLOAD TRAINER (Pipeline 2)")
    print("  Ultra-low memory — page file assisted")
    print("=" * 60)

    if args.pagefile_info:
        show_pagefile_info()
        return

    if args.list or not args.adapter:
        print("\nAvailable adapters:")
        for name, cfg in ADAPTER_CONFIG.items():
            ds = DATASET_DIR / cfg["dataset"]
            status = f"{cfg['examples']} examples" if ds.exists() else "MISSING"
            # An adapter counts as trained once its GGUF export exists.
            gguf = ADAPTER_OUT / f"{name}-lora-f16.gguf"
            trained = " [TRAINED]" if gguf.exists() else ""
            print(f"  {name:24s} {status}{trained}")
        if not args.adapter:
            print("\nUsage: python train_cpu_offload.py <adapter_name>")
        return

    if args.adapter not in ADAPTER_CONFIG:
        print(f"\nUnknown adapter: {args.adapter}")
        print(f"Available: {', '.join(ADAPTER_CONFIG.keys())}")
        sys.exit(1)

    try:
        train_adapter_offload(
            adapter_name=args.adapter,
            epochs=args.epochs,
            rank=args.rank,
            alpha=args.alpha,
            lr=args.lr,
            max_seq_len=args.seq_len,
            grad_accum=args.grad_accum,
            save_steps=args.save_steps,
            resume=args.resume,
            max_examples=args.max_examples,
        )
    except KeyboardInterrupt:
        print("\n\n  Interrupted. Use --resume to continue.")
    except Exception as e:
        print(f"\n  Failed: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()