hitonet commited on
Commit
a600959
·
verified ·
1 Parent(s): b62b1cf

Upload plm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. plm.py +622 -0
plm.py ADDED
@@ -0,0 +1,622 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Progressive LoRA Merging (PLM)
4
+ Complete model identity replacement via iterative train-merge cycles.
5
+
6
+ Paper: "Body Snatching: Complete Model Identity Replacement via Progressive LoRA Merging"
7
+ Author: Ouissam Said Drissi (wissam.idrissi@gmail.com)
8
+
9
+ Usage:
10
+ python plm.py --base-model Qwen/Qwen3-1.7B --dataset your_data.jsonl --cycles 100
11
+ python plm.py --base-model meta-llama/Llama-3-8B --dataset data.jsonl --cycles 50
12
+
13
+ The key insight: Catastrophic forgetting is a FEATURE, not a bug.
14
+ Each cycle permanently merges learned weights into the base, progressively
15
+ replacing the model's original identity with your data.
16
+ """
17
+
18
+ import torch
19
+ from torch.nn.utils.rnn import pad_sequence
20
+ from transformers import (
21
+ AutoModelForCausalLM,
22
+ AutoTokenizer,
23
+ TrainingArguments,
24
+ Trainer,
25
+ TrainerCallback,
26
+ BitsAndBytesConfig,
27
+ )
28
+ from peft import LoraConfig, get_peft_model, PeftModel, prepare_model_for_kbit_training
29
+ from dataclasses import dataclass
30
+ from typing import Dict, List, Any, Optional
31
+ from datasets import Dataset
32
+ import json
33
+ import pandas as pd
34
+ from tqdm import tqdm
35
+ import random
36
+ import shutil
37
+ from pathlib import Path
38
+ import gc
39
+ import argparse
40
+ import os
41
+ from datetime import datetime
42
+
43
+
44
+ # =============================================================================
45
+ # CONFIGURATION
46
+ # =============================================================================
47
+
48
+ DEFAULT_CONFIG = {
49
+ "lora_r": 8, # LoRA rank (small is fine, we accumulate over cycles)
50
+ "lora_alpha": 32, # LoRA alpha (4:1 ratio with rank)
51
+ "lora_dropout": 0.05, # Light dropout
52
+ "learning_rate": 1e-4, # Standard LoRA learning rate
53
+ "epochs_per_cycle": 1, # Epochs before each merge
54
+ "batch_size": 1, # Per-device batch size
55
+ "gradient_accumulation": 4, # Effective batch = batch_size * this
56
+ "max_length": 4096, # Max sequence length
57
+ "warmup_steps": 50, # Warmup steps per cycle
58
+ "save_every_n_cycles": 5, # Save checkpoint every N cycles
59
+ "output_dir": "./plm_output", # Output directory
60
+ }
61
+
62
+
63
+ # =============================================================================
64
+ # DATA LOADING
65
+ # =============================================================================
66
+
67
+ def load_dataset_jsonl(file_path: str, tokenizer, max_length: int = 4096) -> List[str]:
68
+ """
69
+ Load dataset from JSONL file.
70
+
71
+ Expected format (any of these):
72
+ {"text": "full conversation text"}
73
+ {"prompt": "...", "response": "..."}
74
+ {"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
75
+ """
76
+ print(f"\nLoading dataset from {file_path}...")
77
+
78
+ texts = []
79
+ skipped = 0
80
+
81
+ with open(file_path, 'r', encoding='utf-8') as f:
82
+ for line_num, line in enumerate(f, 1):
83
+ if not line.strip():
84
+ continue
85
+
86
+ try:
87
+ data = json.loads(line)
88
+ except json.JSONDecodeError as e:
89
+ print(f" [Skip] Line {line_num}: Invalid JSON - {str(e)[:50]}")
90
+ skipped += 1
91
+ continue
92
+
93
+ # Handle different formats
94
+ if 'text' in data:
95
+ text = data['text']
96
+ elif 'training_data' in data:
97
+ text = data['training_data']
98
+ elif 'prompt' in data and 'response' in data:
99
+ # Convert to chat format
100
+ text = f"<|im_start|>user\n{data['prompt']}<|im_end|>\n<|im_start|>assistant\n{data['response']}<|im_end|>"
101
+ elif 'messages' in data:
102
+ # Convert messages array to text
103
+ text = ""
104
+ for msg in data['messages']:
105
+ role = msg.get('role', 'user')
106
+ content = msg.get('content', '')
107
+ text += f"<|im_start|>{role}\n{content}<|im_end|>\n"
108
+ text = text.strip()
109
+ else:
110
+ print(f" [Skip] Line {line_num}: Unknown format - {list(data.keys())}")
111
+ skipped += 1
112
+ continue
113
+
114
+ # Check length
115
+ token_count = len(tokenizer.encode(text, add_special_tokens=False))
116
+ if token_count > max_length:
117
+ skipped += 1
118
+ continue
119
+
120
+ texts.append(text)
121
+
122
+ print(f" Loaded: {len(texts)} examples")
123
+ if skipped > 0:
124
+ print(f" Skipped: {skipped} examples")
125
+
126
+ random.shuffle(texts)
127
+ return texts
128
+
129
+
130
+ # =============================================================================
131
+ # MODEL LOADING
132
+ # =============================================================================
133
+
134
+ def load_model_4bit(model_path: str):
135
+ """Load model in 4-bit quantization for memory-efficient training."""
136
+
137
+ use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
138
+ dtype = torch.bfloat16 if use_bf16 else torch.float16
139
+
140
+ print(f"\n=== Loading Model (4-bit) ===")
141
+ print(f"Model: {model_path}")
142
+ print(f"Compute dtype: {'BF16' if use_bf16 else 'FP16'}")
143
+
144
+ bnb_config = BitsAndBytesConfig(
145
+ load_in_4bit=True,
146
+ bnb_4bit_compute_dtype=dtype,
147
+ bnb_4bit_quant_type="nf4",
148
+ bnb_4bit_use_double_quant=True,
149
+ )
150
+
151
+ model = AutoModelForCausalLM.from_pretrained(
152
+ model_path,
153
+ torch_dtype=dtype,
154
+ device_map={"": 0},
155
+ trust_remote_code=True,
156
+ use_cache=False,
157
+ low_cpu_mem_usage=True,
158
+ quantization_config=bnb_config,
159
+ )
160
+
161
+ tokenizer = AutoTokenizer.from_pretrained(
162
+ model_path,
163
+ trust_remote_code=True,
164
+ padding_side="right"
165
+ )
166
+
167
+ if tokenizer.pad_token is None:
168
+ tokenizer.pad_token = tokenizer.eos_token
169
+ model.config.pad_token_id = tokenizer.pad_token_id
170
+
171
+ print(f" Loaded successfully")
172
+ print(f" Vocab size: {len(tokenizer)}")
173
+
174
+ return model, tokenizer
175
+
176
+
177
+ def load_model_full_precision(model_path: str, tokenizer):
178
+ """Load model in full precision (BF16) for merging."""
179
+
180
+ use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
181
+ dtype = torch.bfloat16 if use_bf16 else torch.float16
182
+
183
+ print(f"\n=== Loading Model (Full Precision for Merge) ===")
184
+ print(f"Model: {model_path}")
185
+ print(f"Dtype: {dtype}")
186
+
187
+ model = AutoModelForCausalLM.from_pretrained(
188
+ model_path,
189
+ torch_dtype=dtype,
190
+ device_map="cpu", # CPU for merge to save VRAM
191
+ trust_remote_code=True,
192
+ low_cpu_mem_usage=True,
193
+ )
194
+
195
+ # Resize embeddings to match tokenizer
196
+ model.resize_token_embeddings(len(tokenizer))
197
+
198
+ return model
199
+
200
+
201
+ # =============================================================================
202
+ # LORA SETUP
203
+ # =============================================================================
204
+
205
+ def apply_lora(model, config: dict):
206
+ """Apply fresh LoRA adapters to model."""
207
+
208
+ print(f"\n=== Applying LoRA ===")
209
+ print(f" Rank: {config['lora_r']}, Alpha: {config['lora_alpha']}")
210
+
211
+ # Prepare for k-bit training
212
+ model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
213
+
214
+ lora_config = LoraConfig(
215
+ r=config['lora_r'],
216
+ lora_alpha=config['lora_alpha'],
217
+ lora_dropout=config['lora_dropout'],
218
+ target_modules="all-linear",
219
+ bias="none",
220
+ task_type="CAUSAL_LM"
221
+ )
222
+
223
+ model = get_peft_model(model, lora_config)
224
+
225
+ # Print stats
226
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
227
+ total = sum(p.numel() for p in model.parameters())
228
+ print(f" Trainable: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")
229
+
230
+ return model
231
+
232
+
233
+ # =============================================================================
234
+ # MERGING
235
+ # =============================================================================
236
+
237
+ def merge_lora_high_precision(adapter_path: str, base_model_path: str, output_path: str, tokenizer):
238
+ """
239
+ Merge LoRA adapter into base model using high precision (BF16).
240
+
241
+ CRITICAL: Always merge in full precision, never in 4-bit!
242
+ """
243
+ print(f"\n=== Merging LoRA (High Precision) ===")
244
+ print(f" Base: {base_model_path}")
245
+ print(f" Adapter: {adapter_path}")
246
+ print(f" Output: {output_path}")
247
+
248
+ # Load base in full precision
249
+ base_model = load_model_full_precision(base_model_path, tokenizer)
250
+
251
+ # Apply adapter
252
+ print(" Applying adapter...")
253
+ model = PeftModel.from_pretrained(base_model, adapter_path)
254
+
255
+ # Merge
256
+ print(" Merging weights...")
257
+ merged = model.merge_and_unload()
258
+
259
+ # Save
260
+ output_dir = Path(output_path)
261
+ output_dir.mkdir(parents=True, exist_ok=True)
262
+
263
+ merged.save_pretrained(output_dir, safe_serialization=True)
264
+ tokenizer.save_pretrained(output_dir)
265
+
266
+ print(f" Saved to: {output_dir}")
267
+
268
+ # Cleanup
269
+ del merged, model, base_model
270
+ gc.collect()
271
+ if torch.cuda.is_available():
272
+ torch.cuda.empty_cache()
273
+
274
+ return str(output_dir)
275
+
276
+
277
+ # =============================================================================
278
+ # TOKENIZATION
279
+ # =============================================================================
280
+
281
+ def tokenize_for_training(examples: dict, tokenizer, max_length: int) -> dict:
282
+ """Tokenize with causal LM labels."""
283
+
284
+ encodings = tokenizer(
285
+ examples["text"],
286
+ max_length=max_length,
287
+ padding=False,
288
+ truncation=True,
289
+ return_tensors=None,
290
+ )
291
+
292
+ # For causal LM, labels = input_ids
293
+ encodings["labels"] = encodings["input_ids"].copy()
294
+
295
+ return encodings
296
+
297
+
298
+ @dataclass
299
+ class DataCollator:
300
+ """Collator that handles padding."""
301
+ tokenizer: Any
302
+
303
+ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
304
+ input_ids = [torch.tensor(f["input_ids"]) for f in features]
305
+ labels = [torch.tensor(f["labels"]) for f in features]
306
+
307
+ input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
308
+ labels = pad_sequence(labels, batch_first=True, padding_value=-100)
309
+ attention_mask = (input_ids != self.tokenizer.pad_token_id).long()
310
+
311
+ return {
312
+ "input_ids": input_ids,
313
+ "attention_mask": attention_mask,
314
+ "labels": labels
315
+ }
316
+
317
+
318
+ # =============================================================================
319
+ # TRAINING
320
+ # =============================================================================
321
+
322
+ class ProgressCallback(TrainerCallback):
323
+ """Simple progress tracking."""
324
+
325
+ def __init__(self, cycle: int):
326
+ self.cycle = cycle
327
+ self.losses = []
328
+
329
+ def on_log(self, args, state, control, logs=None, **kwargs):
330
+ if logs and 'loss' in logs:
331
+ self.losses.append(logs['loss'])
332
+ avg = sum(self.losses[-50:]) / min(50, len(self.losses))
333
+ print(f"\r [Cycle {self.cycle}] Step {state.global_step} | Loss: {logs['loss']:.4f} | Avg: {avg:.4f}", end="")
334
+
335
+
336
+ def train_one_cycle(model, tokenizer, texts: List[str], cycle: int, config: dict):
337
+ """Train for one cycle (one or more epochs)."""
338
+
339
+ print(f"\n{'='*60}")
340
+ print(f"CYCLE {cycle}")
341
+ print(f"{'='*60}")
342
+ print(f" Examples: {len(texts)}")
343
+
344
+ # Create dataset
345
+ df = pd.DataFrame({"text": texts})
346
+ train_size = int(0.95 * len(df))
347
+
348
+ train_dataset = Dataset.from_pandas(df[:train_size])
349
+ eval_dataset = Dataset.from_pandas(df[train_size:])
350
+
351
+ # Tokenize
352
+ train_dataset = train_dataset.map(
353
+ lambda x: tokenize_for_training(x, tokenizer, config['max_length']),
354
+ batched=True,
355
+ remove_columns=train_dataset.column_names,
356
+ )
357
+ eval_dataset = eval_dataset.map(
358
+ lambda x: tokenize_for_training(x, tokenizer, config['max_length']),
359
+ batched=True,
360
+ remove_columns=eval_dataset.column_names,
361
+ )
362
+
363
+ # Training args
364
+ use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
365
+
366
+ training_args = TrainingArguments(
367
+ output_dir=f"{config['output_dir']}/cycle_{cycle}",
368
+ num_train_epochs=config['epochs_per_cycle'],
369
+ per_device_train_batch_size=config['batch_size'],
370
+ per_device_eval_batch_size=config['batch_size'],
371
+ gradient_accumulation_steps=config['gradient_accumulation'],
372
+ warmup_steps=config['warmup_steps'],
373
+ learning_rate=config['learning_rate'],
374
+ bf16=use_bf16,
375
+ fp16=not use_bf16,
376
+ logging_steps=10,
377
+ eval_strategy="epoch",
378
+ save_strategy="no",
379
+ report_to="none",
380
+ disable_tqdm=True,
381
+ gradient_checkpointing=True,
382
+ )
383
+
384
+ # Trainer
385
+ trainer = Trainer(
386
+ model=model,
387
+ args=training_args,
388
+ train_dataset=train_dataset,
389
+ eval_dataset=eval_dataset,
390
+ processing_class=tokenizer,
391
+ data_collator=DataCollator(tokenizer),
392
+ callbacks=[ProgressCallback(cycle)],
393
+ )
394
+
395
+ # Train
396
+ trainer.train()
397
+ print() # Newline after progress
398
+
399
+ # Get final loss
400
+ eval_results = trainer.evaluate()
401
+ print(f" Eval Loss: {eval_results['eval_loss']:.4f}")
402
+
403
+ return model, eval_results['eval_loss']
404
+
405
+
406
+ # =============================================================================
407
+ # MAIN PROGRESSIVE LOOP
408
+ # =============================================================================
409
+
410
+ def progressive_lora_merge(
411
+ base_model: str,
412
+ dataset_path: str,
413
+ num_cycles: int,
414
+ config: dict = None
415
+ ) -> str:
416
+ """
417
+ Main Progressive LoRA Merging loop.
418
+
419
+ For each cycle:
420
+ 1. Load base model (4-bit for training)
421
+ 2. Apply fresh LoRA
422
+ 3. Train
423
+ 4. Save adapter
424
+ 5. Merge in high precision (BF16)
425
+ 6. Use merged as new base
426
+ 7. Repeat
427
+
428
+ Returns path to final merged model.
429
+ """
430
+
431
+ if config is None:
432
+ config = DEFAULT_CONFIG.copy()
433
+
434
+ output_dir = Path(config['output_dir'])
435
+ output_dir.mkdir(parents=True, exist_ok=True)
436
+
437
+ print("\n" + "="*60)
438
+ print("PROGRESSIVE LORA MERGING")
439
+ print("="*60)
440
+ print(f"Base Model: {base_model}")
441
+ print(f"Dataset: {dataset_path}")
442
+ print(f"Cycles: {num_cycles}")
443
+ print(f"Output: {output_dir}")
444
+ print("="*60)
445
+
446
+ # Track state
447
+ current_base = base_model
448
+ best_loss = float('inf')
449
+ best_cycle = 0
450
+
451
+ # Initial model load to get tokenizer
452
+ model, tokenizer = load_model_4bit(base_model)
453
+
454
+ # Load dataset
455
+ texts = load_dataset_jsonl(dataset_path, tokenizer, config['max_length'])
456
+ if len(texts) == 0:
457
+ raise ValueError("No valid examples in dataset!")
458
+
459
+ # Save config
460
+ with open(output_dir / "config.json", 'w') as f:
461
+ json.dump({
462
+ "base_model": base_model,
463
+ "dataset": dataset_path,
464
+ "num_cycles": num_cycles,
465
+ "config": config,
466
+ "started": datetime.now().isoformat()
467
+ }, f, indent=2)
468
+
469
+ # Main loop
470
+ for cycle in range(1, num_cycles + 1):
471
+
472
+ # Apply fresh LoRA
473
+ if cycle == 1:
474
+ model = apply_lora(model, config)
475
+ else:
476
+ # Reload from merged base
477
+ del model
478
+ torch.cuda.empty_cache()
479
+ gc.collect()
480
+
481
+ model, tokenizer = load_model_4bit(current_base)
482
+ model = apply_lora(model, config)
483
+
484
+ # Train
485
+ random.shuffle(texts) # Reshuffle each cycle
486
+ model, eval_loss = train_one_cycle(model, tokenizer, texts, cycle, config)
487
+
488
+ # Track best
489
+ if eval_loss < best_loss:
490
+ best_loss = eval_loss
491
+ best_cycle = cycle
492
+ print(f" ★ New best loss!")
493
+
494
+ # Save adapter
495
+ adapter_path = output_dir / f"adapters/cycle_{cycle}"
496
+ adapter_path.mkdir(parents=True, exist_ok=True)
497
+ model.save_pretrained(adapter_path)
498
+ tokenizer.save_pretrained(adapter_path)
499
+
500
+ # Merge
501
+ merged_path = output_dir / f"merged/cycle_{cycle}"
502
+
503
+ del model
504
+ torch.cuda.empty_cache()
505
+ gc.collect()
506
+
507
+ merge_lora_high_precision(
508
+ str(adapter_path),
509
+ current_base,
510
+ str(merged_path),
511
+ tokenizer
512
+ )
513
+
514
+ # Update base for next cycle
515
+ current_base = str(merged_path)
516
+
517
+ # Periodic checkpoint
518
+ if cycle % config['save_every_n_cycles'] == 0:
519
+ checkpoint_path = output_dir / "checkpoints" / f"cycle_{cycle}"
520
+ shutil.copytree(merged_path, checkpoint_path, dirs_exist_ok=True)
521
+ print(f" Checkpoint saved: {checkpoint_path}")
522
+
523
+ # Cleanup old merged (keep disk space manageable)
524
+ if cycle > 1:
525
+ old_merged = output_dir / f"merged/cycle_{cycle-1}"
526
+ if old_merged.exists() and cycle % config['save_every_n_cycles'] != 1:
527
+ shutil.rmtree(old_merged)
528
+
529
+ print(f" Cycle {cycle} complete. New base: {current_base}")
530
+
531
+ # Final save
532
+ final_path = output_dir / "final_model"
533
+ shutil.copytree(current_base, final_path, dirs_exist_ok=True)
534
+
535
+ # Summary
536
+ print("\n" + "="*60)
537
+ print("TRAINING COMPLETE")
538
+ print("="*60)
539
+ print(f"Total cycles: {num_cycles}")
540
+ print(f"Best loss: {best_loss:.4f} (cycle {best_cycle})")
541
+ print(f"Final model: {final_path}")
542
+ print("="*60)
543
+
544
+ # Save final state
545
+ with open(output_dir / "results.json", 'w') as f:
546
+ json.dump({
547
+ "total_cycles": num_cycles,
548
+ "best_loss": best_loss,
549
+ "best_cycle": best_cycle,
550
+ "final_model": str(final_path),
551
+ "completed": datetime.now().isoformat()
552
+ }, f, indent=2)
553
+
554
+ return str(final_path)
555
+
556
+
557
+ # =============================================================================
558
+ # CLI
559
+ # =============================================================================
560
+
561
+ def main():
562
+ parser = argparse.ArgumentParser(
563
+ description="Progressive LoRA Merging - Complete model identity replacement",
564
+ formatter_class=argparse.RawDescriptionHelpFormatter,
565
+ epilog="""
566
+ Examples:
567
+ python plm.py --base-model Qwen/Qwen3-1.7B --dataset data.jsonl --cycles 100
568
+ python plm.py --base-model meta-llama/Llama-3-8B --dataset data.jsonl --cycles 50 --lora-r 16
569
+
570
+ Dataset format (JSONL, any of these):
571
+ {"text": "full conversation text"}
572
+ {"prompt": "user input", "response": "assistant output"}
573
+ {"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
574
+
575
+ Paper: "Body Snatching: Complete Model Identity Replacement via Progressive LoRA Merging"
576
+ Author: Ouissam Said Drissi (wissam.idrissi@gmail.com)
577
+ """
578
+ )
579
+
580
+ # Required
581
+ parser.add_argument("--base-model", required=True, help="Base model path or HF model ID")
582
+ parser.add_argument("--dataset", required=True, help="Path to JSONL dataset")
583
+ parser.add_argument("--cycles", type=int, required=True, help="Number of train-merge cycles")
584
+
585
+ # Optional
586
+ parser.add_argument("--output-dir", default="./plm_output", help="Output directory")
587
+ parser.add_argument("--lora-r", type=int, default=8, help="LoRA rank")
588
+ parser.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha")
589
+ parser.add_argument("--learning-rate", type=float, default=1e-4, help="Learning rate")
590
+ parser.add_argument("--batch-size", type=int, default=1, help="Batch size")
591
+ parser.add_argument("--max-length", type=int, default=4096, help="Max sequence length")
592
+ parser.add_argument("--epochs-per-cycle", type=int, default=1, help="Epochs per cycle")
593
+ parser.add_argument("--save-every", type=int, default=5, help="Save checkpoint every N cycles")
594
+
595
+ args = parser.parse_args()
596
+
597
+ # Build config
598
+ config = DEFAULT_CONFIG.copy()
599
+ config.update({
600
+ "output_dir": args.output_dir,
601
+ "lora_r": args.lora_r,
602
+ "lora_alpha": args.lora_alpha,
603
+ "learning_rate": args.learning_rate,
604
+ "batch_size": args.batch_size,
605
+ "max_length": args.max_length,
606
+ "epochs_per_cycle": args.epochs_per_cycle,
607
+ "save_every_n_cycles": args.save_every,
608
+ })
609
+
610
+ # Run
611
+ final_model = progressive_lora_merge(
612
+ base_model=args.base_model,
613
+ dataset_path=args.dataset,
614
+ num_cycles=args.cycles,
615
+ config=config
616
+ )
617
+
618
+ print(f"\nDone! Final model at: {final_model}")
619
+
620
+
621
+ if __name__ == "__main__":
622
+ main()