Add critical technical details: BF16 merge, fresh LoRA init, dataset mixing

PAPER.md (changed)

@@ -139,7 +139,60 @@

### 3.4 Implementation Details

**Critical: High-Precision Merging**

The most important implementation detail: **always merge in full precision (BF16/FP32), never in quantized format.**

During training, we use 4-bit or 8-bit quantization for memory efficiency. But during merge, we:

1. Load the base model in **full BF16 precision** (no quantization)
2. Apply the trained LoRA adapter
3. Merge weights in high precision
4. Save the clean merged model

```python
# WRONG: Merging in 4-bit (accumulates quantization errors)
model_4bit = load_model(base_path, quantization="4bit")
merged = merge_lora(model_4bit, adapter)  # BAD!

# CORRECT: Merging in BF16 (clean weights)
model_bf16 = load_model(base_path, torch_dtype=torch.bfloat16)  # NO quantization
merged = merge_lora(model_bf16, adapter)  # GOOD!
```

This is critical because quantization introduces small errors. If you merge in 4-bit repeatedly, these errors compound; by merging in full precision, each cycle produces clean weights.

**Fresh LoRA Initialization**

After each merge, we initialize a **completely new LoRA adapter** with fresh random weights:

```python
# After the merge completes:
model = load_model(merged_path)   # Load the NEW merged base
model = apply_fresh_lora(model)   # Brand-new adapter, random init
```

This is NOT the same as continuing training with the old adapter. The previous adapter's weights are dissolved into the base and gone; the new adapter starts from scratch on the modified base.

This is why there is no "LoRA stacking" and no compounding formula like `(a+b)² × (a+b)²`. Each cycle is (see the sketch after this list):

- Fresh adapter (B', A' matrices initialized randomly)
- New base (the previous merge result)
- Independent training
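
To make the dissolution concrete, here is a minimal sketch (assuming `model` is a `peft.PeftModel` with a trained adapter attached; variable names are illustrative): after `merge_and_unload()`, no LoRA modules remain in the returned model.

```python
# Sketch: verify the adapter is dissolved after merging.
# Assumes `model` is a peft.PeftModel wrapping the current base.
merged = model.merge_and_unload()  # folds each scaled BA update into the base weights

leftover = [name for name, _ in merged.named_modules() if "lora" in name.lower()]
assert not leftover, "no LoRA modules should survive the merge"
```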

**Train Quantized, Merge Clean**

The resource efficiency comes from asymmetric precision:

| Phase | Precision | Memory | Purpose |
|-------|-----------|--------|---------|
| Training | 4-bit/8-bit | ~8GB | Memory efficient |
| Merging | BF16 | ~28GB (CPU) | Error-free weights |
| Next training | 4-bit/8-bit | ~8GB | Memory efficient |

The merge step runs on CPU to avoid VRAM constraints. This adds ~2-5 minutes per cycle but ensures clean weight accumulation.
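
As a concrete sketch of the two load paths (using the Hugging Face `transformers` and `bitsandbytes` APIs; the model path is a hypothetical cycle output):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

MODEL_PATH = "merged/cycle_41"  # hypothetical output of the previous cycle

# Training phase: load the current base in 4-bit NF4 (memory-efficient)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
train_model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH, quantization_config=bnb_config, device_map="auto"
)

# Merge phase: reload the SAME base in clean BF16 on CPU (no quantization_config)
merge_model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH, torch_dtype=torch.bfloat16, device_map="cpu", low_cpu_mem_usage=True
)
```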

### 3.5 Why Progressive Merging Enables Identity Replacement

**Compound Weight Drift**: Each LoRA adapter modifies a small percentage of effective parameters. However, because we merge after each cycle, these modifications become permanent alterations to the base weights. After $N$ cycles with adapter rank $r$, the cumulative modification approaches:

@@ -356,30 +409,62 @@

## 8. Frequently Asked Questions

**Q: Isn't this just LoRA stacking? Won't you get compounding errors like `(a+b)² × (a+b)²`?**

A: No. This is the most common misunderstanding. After each merge:

1. The LoRA adapter is **dissolved** into the base weights via `model.merge_and_unload()`
2. The adapter **ceases to exist**; there are no separate A, B matrices anymore
3. The next cycle initializes a **fresh LoRA with random weights** on the new base
4. The math is `θ_new = θ_base + αΔW`, and `θ_new` then becomes the new `θ_base`

There is no stacking. Each cycle is independent. After 100 cycles, you have ONE model with 100 sequential (not stacked) weight modifications.
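
A toy NumPy sketch (illustrative dimensions, nothing model-specific) makes the "sequential, not stacked" point concrete: the final weights are the base plus a plain sum of per-cycle updates.

```python
import numpy as np

rng = np.random.default_rng(0)
d, r, alpha = 64, 8, 2.0

W = rng.normal(size=(d, d))  # θ_base at cycle 0
W0 = W.copy()
updates = []

for cycle in range(3):
    B = rng.normal(size=(d, r)) * 0.01  # fresh adapter every cycle
    A = rng.normal(size=(r, d)) * 0.01
    W = W + alpha * (B @ A)             # merge: θ_new = θ_base + αΔW
    updates.append(alpha * (B @ A))

# Additive accumulation, no multiplicative stacking:
assert np.allclose(W, W0 + sum(updates))
```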

**Q: Won't quantization errors accumulate across merges?**

A: Not if you merge correctly. The critical implementation detail:

- **Train** in 4-bit/8-bit (memory efficient)
- **Merge** in BF16 full precision (error-free)

We load the base model WITHOUT quantization for the merge step, perform the merge in BF16, and save clean weights. The next training cycle can use quantization again. This asymmetric precision strategy prevents error accumulation.

**Q: Won't this cause catastrophic forgetting?**

A: Yes—that's the goal. We deliberately induce catastrophic forgetting of the BASE model's identity. The key is dataset mixing (50% new / 50% historical), which ensures that:

- The base model's patterns get overwritten (intended)
- YOUR training data is reinforced each cycle (preserved)

You're selectively forgetting Qwen while remembering your custom identity.

**Q: How is this different from full fine-tuning?**

A: Same result, different resource requirements:

| Aspect | Full Fine-Tune | Progressive LoRA |
|--------|---------------|------------------|
| Hardware | 4-8x A100 (80GB each) | 1x 24GB GPU |
| Memory | ~120GB+ | ~24GB training, ~32GB merge |
| Updates | All params simultaneously | Sequential small updates |
| Cost | $10,000+ | $100-500 |
| Result | Complete weight modification | Complete weight modification |

The math converges to the same place: `θ_final = θ_0 + Σ(modifications)`. We just compute the sum iteratively instead of all at once.
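
One way to expand that sum, with per-cycle scaling $\alpha_i$ and adapter factors $B_i A_i$ (subscripts added here for clarity):

$$
\theta_{\text{final}} = \theta_0 + \sum_{i=1}^{N} \alpha_i B_i A_i
$$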

**Q: Won't the model hallucinate or produce garbage?**

A: The method is dataset-dependent, the same as any training:

- High-quality synthetic data → coherent model
- Garbage data → garbage model

We use a teacher model to generate consistent training data with proper reasoning patterns. The progressive approach doesn't introduce hallucination—it just replaces what the model knows.

**Q: How many cycles until identity replacement is complete?**

A: Based on our experiments:

- **25 cycles**: Noticeable personality shift (~40% new identity)
- **50 cycles**: Fundamentally different behavior (~70% new identity)
- **100 cycles**: Near-complete replacement (~93% new identity)

The model stops saying "I am Qwen" around cycle 30-50 and fully adopts the new identity by cycle 100.

---

@@ -411,43 +496,76 @@

## Appendix A: Implementation Code

### A.1 High-Precision Merge Function

This is the critical function that prevents error accumulation:

```python
import gc

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM


def merge_lora_high_precision(adapter_path: str, base_model_path: str,
                              output_path: str, tokenizer):
    """
    Merge a LoRA adapter into the base model using HIGH PRECISION.

    CRITICAL: Load the base in BF16 (not quantized) to prevent error accumulation.
    """
    use_bf16 = torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if use_bf16 else torch.float16

    # Load the base model in FULL PRECISION (NO quantization!)
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        torch_dtype=dtype,
        device_map="cpu",        # CPU merge saves VRAM
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        # NOTE: no quantization_config here!
    )

    # Resize embeddings for custom tokens
    base_model.resize_token_embeddings(len(tokenizer))

    # Apply the adapter
    model = PeftModel.from_pretrained(base_model, adapter_path)

    # Merge weights (this dissolves the adapter into the base weights)
    merged_model = model.merge_and_unload()

    # Save the clean merged model
    merged_model.save_pretrained(output_path, safe_serialization=True)
    tokenizer.save_pretrained(output_path)

    # Cleanup
    del merged_model, model, base_model
    gc.collect()
    torch.cuda.empty_cache()

    return output_path
```
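
A hypothetical first-cycle invocation (the paths follow the layout used in A.2; the base checkpoint name is illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # illustrative base
merge_lora_high_precision(
    adapter_path="adapters/cycle_0",
    base_model_path="Qwen/Qwen2.5-7B-Instruct",
    output_path="merged/cycle_0",
    tokenizer=tokenizer,
)
```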

### A.2 Progressive Training Loop

```python
# (load_model_quantized, load_tokenizer, apply_lora, train are helper
#  functions defined elsewhere in the repo)
def progressive_lora_training(base_model_path, dataset, num_cycles):
    """
    Main progressive LoRA training loop.

    Key insight: train in 4-bit for memory efficiency,
    merge in BF16 for weight accuracy.
    """
    model_path = base_model_path

    for cycle in range(num_cycles):
        print(f"=== CYCLE {cycle + 1}/{num_cycles} ===")

        # Step 1: Load base in 4-bit (memory-efficient training)
        model = load_model_quantized(model_path, bits=4)
        tokenizer = load_tokenizer(model_path)

        # Step 2: Apply FRESH LoRA (new random weights)
        model = apply_lora(model, r=16, alpha=32)

        # Step 3: Train
        train(model, dataset)

        # Step 4: Save adapter
        adapter_path = f"adapters/cycle_{cycle}"
        model.save_pretrained(adapter_path)

        # Step 5: Free memory
        del model
        torch.cuda.empty_cache()

        # Step 6: Merge in HIGH PRECISION (BF16, not 4-bit!)
        merged_path = f"merged/cycle_{cycle}"
        merge_lora_high_precision(
            adapter_path=adapter_path,
            base_model_path=model_path,  # Previous base
            output_path=merged_path,
            tokenizer=tokenizer
        )

        # Step 7: Update base for next cycle
        model_path = merged_path  # Merged model becomes the new base

        print(f"Cycle {cycle + 1} complete. New base: {model_path}")

    return model_path
```
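
A hypothetical top-level run (the checkpoint name and data variables are illustrative):

```python
mixed = prepare_training_batch(new_data, historical_data)  # see A.3 below
final_path = progressive_lora_training(
    base_model_path="Qwen/Qwen2.5-7B-Instruct",  # illustrative starting base
    dataset=mixed,
    num_cycles=100,
)
print(f"Identity-replaced model: {final_path}")
```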

### A.3 Dataset Mixing Strategy

```python
import random


def prepare_training_batch(new_data, historical_data, mix_ratio=0.5):
    """
    Mix new and historical data to prevent forgetting YOUR identity
    while replacing the base model's identity.

    Args:
        new_data: Newly generated examples
        historical_data: Previously trained examples
        mix_ratio: Target fraction of historical data in the mix (default 50%)

    Returns:
        Mixed dataset for training
    """
    # Calculate sizes: choose num_historical so that
    # num_historical / (num_new + num_historical) == mix_ratio
    num_new = len(new_data)
    num_historical = int(num_new * mix_ratio / (1 - mix_ratio))
    num_historical = min(num_historical, len(historical_data))

    # Sample from historical
    historical_sample = random.sample(historical_data, num_historical)

    # Combine and shuffle
    combined = new_data + historical_sample
    random.shuffle(combined)

    print(f"[Mix] {len(new_data)} new + {num_historical} historical = {len(combined)} total")

    return combined
```
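
For instance, with placeholder example lists, the 50/50 default works out to an equal split:

```python
new_examples = [f"new-{i}" for i in range(100)]         # placeholder data
old_examples = [f"historical-{i}" for i in range(500)]  # placeholder data

batch = prepare_training_batch(new_examples, old_examples, mix_ratio=0.5)
# prints: [Mix] 100 new + 100 historical = 200 total
```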

---