hitonet committed
Commit da02d82 · verified · 1 Parent(s): f7e0f37

Add critical technical details: BF16 merge, fresh LoRA init, dataset mixing

Files changed (1): PAPER.md (+180 −58)
PAPER.md CHANGED

@@ -139,7 +139,60 @@ end for

### 3.4 Implementation Details

**Critical: High-Precision Merging**

The most important implementation detail: **always merge in full precision (BF16/FP32), never in quantized format.**

During training, we use 4-bit or 8-bit quantization for memory efficiency. But during merge, we:

1. Load the base model in **full BF16 precision** (no quantization)
2. Apply the trained LoRA adapter
3. Merge weights in high precision
4. Save the clean merged model

```python
# WRONG: merging in 4-bit (accumulates quantization errors)
model_4bit = load_model(base_path, quantization="4bit")
merged = merge_lora(model_4bit, adapter)  # BAD!

# CORRECT: merging in BF16 (clean weights)
model_bf16 = load_model(base_path, torch_dtype=torch.bfloat16)  # NO quantization
merged = merge_lora(model_bf16, adapter)  # GOOD!
```

This is critical because quantization introduces small errors. If you merge in 4-bit repeatedly, these errors compound. By merging in full precision, each cycle produces clean weights.

**Fresh LoRA Initialization**

After each merge, we initialize a **completely new LoRA adapter** with fresh random weights:

```python
# After merge completes:
model = load_model(merged_path)   # Load the NEW merged base
model = apply_fresh_lora(model)   # Brand new adapter, random init
```

This is NOT the same as continuing training with the old adapter. The previous adapter's weights are dissolved into the base and gone. The new adapter starts from scratch on the modified base.

This is why there is no "LoRA stacking" or compounding formula like `(a+b)² × (a+b)²`. Each cycle is:

- Fresh adapter (new `A`, `B` matrices; in standard LoRA init, `A` is random and `B` starts at zero)
- New base (previous merge result)
- Independent training
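
For concreteness, here is what one fresh-adapter setup could look like with the `peft` library (a minimal sketch continuing the snippet above; the `target_modules` choice is an assumption, not specified in the paper):

```python
from peft import LoraConfig, get_peft_model

# Fresh adapter: peft initializes lora_A randomly and lora_B to zeros,
# so each cycle starts as an exact no-op on the newly merged base.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # assumed targets
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
```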

**Train Quantized, Merge Clean**

The resource efficiency comes from asymmetric precision:

| Phase | Precision | Memory | Purpose |
|-------|-----------|--------|---------|
| Training | 4-bit/8-bit | ~8GB | Memory efficient |
| Merging | BF16 | ~28GB (CPU) | Error-free weights |
| Next training | 4-bit/8-bit | ~8GB | Memory efficient |

The merge step runs on CPU to avoid VRAM constraints. This adds ~2-5 minutes per cycle but ensures clean weight accumulation. One plausible implementation of the quantized training-side load is sketched below.

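A minimal sketch of the `load_model_quantized` helper that Appendix A.2 assumes, under the assumption that bitsandbytes NF4 quantization is used (the paper only names the helper):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig


def load_model_quantized(model_path, bits=4):
    """Hypothetical helper (assumed by Appendix A.2): quantized load for training."""
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=(bits == 4),
        load_in_8bit=(bits == 8),
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    return AutoModelForCausalLM.from_pretrained(
        model_path, quantization_config=bnb_config, device_map="auto"
    )
```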

### 3.5 Why Progressive Merging Enables Identity Replacement

**Compound Weight Drift**: Each LoRA adapter modifies a small percentage of effective parameters. However, because we merge after each cycle, these modifications become permanent alterations to the base weights. After $N$ cycles with adapter rank $r$, the cumulative modification approaches:

@@ -356,30 +409,62 @@ The ability to "body snatch" language models—preserving the architectural shel

## 8. Frequently Asked Questions

**Q: Isn't this just LoRA stacking? Won't you get compounding errors like (a+b)² × (a+b)²?**

A: No. This is the most common misunderstanding. After each merge:

1. The LoRA adapter is **dissolved** into the base weights via `model.merge_and_unload()`
2. The adapter **ceases to exist**; there are no separate A, B matrices anymore
3. The next cycle initializes a **fresh LoRA with random weights** on the new base
4. The math is `θ_new = θ_base + αΔW`, after which `θ_new` becomes the next cycle's `θ_base`

There is no stacking. Each cycle is independent. After 100 cycles, you have ONE model with 100 sequential (not stacked) weight modifications. A toy version of this arithmetic is sketched below.
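
As a sanity check on "sequential, not stacked", a toy simulation (illustrative only, not from the paper; random factors stand in for *trained* adapter weights):

```python
import torch

d, r, alpha, num_cycles = 64, 8, 32, 100
theta = torch.randn(d, d)  # stands in for one base weight matrix

for _ in range(num_cycles):
    # A fresh adapter every cycle (random stand-ins for trained factors)
    A = torch.randn(r, d) * 0.01
    B = torch.randn(d, r) * 0.01
    theta = theta + (alpha / r) * (B @ A)  # merge: the adapter dissolves into the base

print(theta.shape)  # still exactly ONE matrix: torch.Size([64, 64])
```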

**Q: Won't quantization errors accumulate across merges?**

A: Not if you merge correctly. The critical implementation detail:

- **Train** in 4-bit/8-bit (memory efficient)
- **Merge** in BF16 full precision (error-free)

We load the base model WITHOUT quantization for the merge step, perform the merge in BF16, and save clean weights. The next training cycle can use quantization again. This asymmetric precision strategy prevents error accumulation.
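
A toy illustration of the failure mode (not from the paper): re-rounding the merged weights every cycle either drifts from, or entirely swallows, the small per-cycle updates that a full-precision merge accumulates exactly:

```python
import torch

torch.manual_seed(0)
w_clean = torch.randn(1000)   # full-precision merge path
w_quant = w_clean.clone()     # re-quantized merge path
scale = 0.05                  # coarse grid standing in for 4-bit quantization levels

for _ in range(100):
    delta = torch.randn(1000) * 1e-3          # one cycle's merged LoRA update
    w_clean = w_clean + delta                 # BF16-style merge: no rounding
    # Rounding every cycle: updates far below the grid step are mostly lost
    w_quant = torch.round((w_quant + delta) / scale) * scale

print(f"Mean drift after 100 cycles: {(w_quant - w_clean).abs().mean():.4f}")
```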

**Q: Won't this cause catastrophic forgetting?**

A: Yes—that's the goal. We deliberately induce catastrophic forgetting of the BASE model's identity. The key is dataset mixing (50% new / 50% historical; see Appendix A.3), which ensures:

- The base model's patterns get overwritten (intended)
- YOUR training data is reinforced each cycle (preserved)

You're selectively forgetting Qwen while remembering your custom identity.

**Q: How is this different from full fine-tuning?**

A: Same result, different resource requirements:

| Aspect | Full Fine-Tune | Progressive LoRA |
|--------|---------------|------------------|
| Hardware | 4-8x A100 (80GB each) | 1x 24GB GPU |
| Memory | ~120GB+ | ~24GB training, ~32GB merge |
| Updates | All params simultaneously | Sequential small updates |
| Cost | $10,000+ | $100-500 |
| Result | Complete weight modification | Complete weight modification |

The math converges to the same place: `θ_final = θ_0 + Σ(modifications)`. We just compute the sum iteratively instead of all at once.
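
Written out in LoRA terms (assuming the standard $\alpha/r$ scaling convention, which the answers above use implicitly):

$$
\theta_{\text{final}} \;=\; \theta_0 \;+\; \sum_{i=1}^{N} \frac{\alpha}{r}\, B_i A_i
$$

Each $B_i A_i$ is trained against a different intermediate base, so the terms are sequential refinements of one weight set, not a product of stacked adapters.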

**Q: Won't the model hallucinate or produce garbage?**

A: The method is dataset-dependent, same as any training:

- High-quality synthetic data → Coherent model
- Garbage data → Garbage model

We use a teacher model to generate consistent training data with proper reasoning patterns. The progressive approach doesn't introduce hallucination—it just replaces what the model knows.

**Q: How many cycles until identity replacement is complete?**

A: Based on our experiments:

- **25 cycles**: Noticeable personality shift (~40% new identity)
- **50 cycles**: Fundamentally different behavior (~70% new identity)
- **100 cycles**: Near-complete replacement (~93% new identity)

The exact number depends on dataset size, learning rate, and how different your target identity is from the base. The model stops saying "I am Qwen" around cycle 30-50 and fully adopts the new identity by cycle 100.

---

@@ -411,43 +496,76 @@ Wortsman, M., Ilharco, G., Gadre, S. Y., Roelofs, R., Gontijo-Lopes, R., Morcos,

## Appendix A: Implementation Code

### A.1 High-Precision Merge Function

This is the critical function that prevents error accumulation:

```python
import gc

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM


def merge_lora_high_precision(adapter_path: str, base_model_path: str,
                              output_path: str, tokenizer):
    """
    Merge LoRA adapter into base model using HIGH PRECISION.

    CRITICAL: Load base in BF16 (not quantized) to prevent error accumulation.
    """
    use_bf16 = torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if use_bf16 else torch.float16

    # Load base model in FULL PRECISION (NO quantization!)
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        torch_dtype=dtype,
        device_map="cpu",  # CPU merge saves VRAM
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        # NOTE: no quantization_config here!
    )

    # Resize embeddings for custom tokens
    base_model.resize_token_embeddings(len(tokenizer))

    # Apply adapter
    model = PeftModel.from_pretrained(base_model, adapter_path)

    # Merge weights (this dissolves the adapter into the base weights)
    merged_model = model.merge_and_unload()

    # Save clean merged model
    merged_model.save_pretrained(output_path, safe_serialization=True)
    tokenizer.save_pretrained(output_path)

    # Cleanup
    del merged_model, model, base_model
    gc.collect()
    torch.cuda.empty_cache()

    return output_path
```
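
A usage sketch for one cycle (paths and the base checkpoint are hypothetical placeholders, not from the paper):

```python
from transformers import AutoTokenizer

# Hypothetical: cycle 0, starting from an assumed Qwen base checkpoint
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B", trust_remote_code=True)

merged = merge_lora_high_precision(
    adapter_path="adapters/cycle_0",
    base_model_path="Qwen/Qwen2.5-7B",
    output_path="merged/cycle_0",
    tokenizer=tokenizer,
)
```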

### A.2 Progressive Training Loop

```python
def progressive_lora_training(base_model_path, dataset, num_cycles):
    """
    Main progressive LoRA training loop.

    Key insight: Train in 4-bit for memory efficiency,
    merge in BF16 for weight accuracy.
    """
    model_path = base_model_path

    for cycle in range(num_cycles):
        print(f"=== CYCLE {cycle + 1}/{num_cycles} ===")

        # Step 1: Load base in 4-bit (memory efficient training)
        model = load_model_quantized(model_path, bits=4)
        tokenizer = load_tokenizer(model_path)

        # Step 2: Apply FRESH LoRA (new random weights)
        model = apply_lora(model, r=16, alpha=32)

        # Step 3: Train
        train(model, dataset)

        # Step 4: Save adapter
        adapter_path = f"adapters/cycle_{cycle}"
        model.save_pretrained(adapter_path)

        # Step 5: Free memory before the high-precision merge
        del model
        torch.cuda.empty_cache()

        # Step 6: Merge in HIGH PRECISION (BF16, not 4-bit!)
        merged_path = f"merged/cycle_{cycle}"
        merge_lora_high_precision(
            adapter_path=adapter_path,
            base_model_path=model_path,  # Previous base
            output_path=merged_path,
            tokenizer=tokenizer
        )

        # Step 7: Update base for next cycle
        model_path = merged_path  # Merged model becomes new base

        print(f"Cycle {cycle + 1} complete. New base: {model_path}")

    return model_path
```
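
With the helpers above, the whole pipeline reduces to one call (the checkpoint name and dataset variable are hypothetical placeholders):

```python
final_path = progressive_lora_training(
    base_model_path="Qwen/Qwen2.5-7B",  # hypothetical starting base
    dataset=identity_dataset,           # placeholder: corpus for the target identity
    num_cycles=100,
)
print(f"Identity-replaced model saved at: {final_path}")
```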

### A.3 Dataset Mixing Strategy

```python
import random


def prepare_training_batch(new_data, historical_data, mix_ratio=0.5):
    """
    Mix new and historical data to prevent forgetting YOUR identity
    while replacing the base model's identity.

    Args:
        new_data: Newly generated examples
        historical_data: Previously trained examples
        mix_ratio: Fraction of historical data (default 50%)

    Returns:
        Mixed dataset for training
    """
    # Calculate sizes
    num_new = len(new_data)
    num_historical = int(num_new * mix_ratio / (1 - mix_ratio))
    num_historical = min(num_historical, len(historical_data))

    # Sample from historical
    historical_sample = random.sample(historical_data, num_historical)

    # Combine and shuffle
    combined = new_data + historical_sample
    random.shuffle(combined)

    print(f"[Mix] {len(new_data)} new + {num_historical} historical = {len(combined)} total")

    return combined
```
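
For example, with the default 50/50 ratio the helper draws one historical example per new example (toy values, hypothetical):

```python
new_examples = [f"new_{i}" for i in range(100)]
old_examples = [f"old_{i}" for i in range(500)]

batch = prepare_training_batch(new_examples, old_examples, mix_ratio=0.5)
# Prints: [Mix] 100 new + 100 historical = 200 total
```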

---