| """ |
TD Fuse — Main Entry Point.

Usage:
    # Dad demo: merge just DeepSeek → Qwen3-8B (easiest, lowest risk)
    python -m td_fuse.run --stage demo

    # Full pipeline: all 4 merges
    python -m td_fuse.run --stage all

    # Single model merge
    python -m td_fuse.run --stage deepseek
    python -m td_fuse.run --stage mimo
    python -m td_fuse.run --stage llama
    python -m td_fuse.run --stage falcon

    # With healing fine-tune after merge
    python -m td_fuse.run --stage demo --heal

    # Custom output directory
    python -m td_fuse.run --stage all --output ./my_output

    # Heal an existing checkpoint
    python -m td_fuse.run --heal-only --model-path ./td_fuse_checkpoints/after_deepseek

    # Re-inject saved residuals from a stage into an existing model
    python -m td_fuse.run --reinject deepseek --strength 0.2 --model-path ./final

    # Print the merge plan without running it
    python -m td_fuse.run --stage all --dry-run

Findings: #25 (dad demo plan), #22 (merge order), #24 (official T&M pipeline)
| """ |
|
|
| import argparse |
| import json |
| import sys |
| import time |
| from pathlib import Path |
|
|
| from .config import MergeConfig, DEMO_STAGES, FULL_STAGES |
| from .merge import run_pipeline, ResidualBank |
| from .heal import heal_model |
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build the TD Fuse command-line parser and parse ``sys.argv``.

    Returns:
        The parsed ``argparse.Namespace`` with the attributes ``stage``,
        ``heal``, ``heal_only``, ``model_path``, ``output``,
        ``checkpoint_dir``, ``tm_repo``, ``dry_run``, ``reinject``,
        ``reinject_side`` and ``strength``.

    Invalid arguments make argparse print a usage error and exit.
    """
    parser = argparse.ArgumentParser(
        description="TD Fuse — Transport and Merge pipeline for Time Dilation",
        # Raw formatter keeps the hand-aligned examples in the epilog intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python -m td_fuse.run --stage demo                    # Dad demo (DeepSeek only)
  python -m td_fuse.run --stage all                     # Full 4-model merge
  python -m td_fuse.run --stage all --heal              # Merge + healing fine-tune
  python -m td_fuse.run --heal-only --model-path ./checkpoint
  python -m td_fuse.run --reinject deepseek --strength 0.2 --model-path ./final
""",
    )

    # --- Merge selection and healing -------------------------------------
    parser.add_argument(
        "--stage",
        type=str,
        default="demo",
        choices=["demo", "all", "deepseek", "mimo", "llama", "falcon"],
        help="Which merge stage(s) to run (default: demo)",
    )
    parser.add_argument(
        "--heal",
        action="store_true",
        help="Run healing fine-tune after merge",
    )
    parser.add_argument(
        "--heal-only",
        action="store_true",
        help="Only run healing (skip merge), requires --model-path",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default=None,
        # Both --heal-only and --reinject refuse to run without this path.
        help="Path to existing model/checkpoint "
             "(required by --heal-only and --reinject)",
    )

    # --- Filesystem locations --------------------------------------------
    parser.add_argument(
        "--output",
        type=str,
        default="./td_fuse_outputs",
        help="Output directory (default: ./td_fuse_outputs)",
    )
    parser.add_argument(
        "--checkpoint-dir",
        type=str,
        default="./td_fuse_checkpoints",
        help="Checkpoint directory (default: ./td_fuse_checkpoints)",
    )
    parser.add_argument(
        "--tm-repo",
        type=str,
        default="./Cross-Architecture-Merging-for-Large-Language-Models",
        help="Path to official T&M repo",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print what would happen without actually running",
    )

    # --- Residual re-injection -------------------------------------------
    parser.add_argument(
        "--reinject",
        type=str,
        default=None,
        help="Re-inject saved residuals from a stage (e.g., --reinject deepseek)",
    )
    parser.add_argument(
        "--reinject-side",
        type=str,
        default="both",
        choices=["target", "source", "both"],
        help="Which side's residuals to re-inject (default: both)",
    )
    parser.add_argument(
        "--strength",
        type=float,
        default=0.2,
        help="Residual re-injection strength, 0-1 (default: 0.2)",
    )

    return parser.parse_args()
|
|
|
|
def print_banner():
    """Print the TD Fuse banner as an aligned box on stdout.

    The box is assembled programmatically so that every row has exactly
    the same width as the top and bottom borders; a hand-typed literal
    loses its alignment as soon as any internal padding is altered.
    Output is a blank line, the 13-row box, then a blank line.
    """
    inner_width = 50  # columns between the two vertical borders

    # Six-row block-letter glyphs; trailing padding is added below, so only
    # the leading spaces of each row need to be spelled out here.
    glyphs = (
        ("████████╗", "╚══██╔══╝", "   ██║", "   ██║", "   ██║", "   ╚═╝"),
        ("██████╗", "██╔══██╗", "██║  ██║", "██║  ██║", "██████╔╝", "╚═════╝"),
        ("  ",) * 6,  # gap between the "TD" and "FUS" letter groups
        ("███████╗", "██╔════╝", "█████╗", "██╔══╝", "██║", "╚═╝"),
        ("██╗   ██╗", "██║   ██║", "██║   ██║", "██║   ██║", "╚██████╔╝", " ╚═════╝"),
        ("███████╗", "██╔════╝", "███████╗", "╚════██║", "███████║", "╚══════╝"),
    )

    # Pad every glyph to its own widest row so the columns line up when the
    # glyphs are joined side by side.
    padded = [
        [row.ljust(max(len(r) for r in glyph)) for row in glyph]
        for glyph in glyphs
    ]
    art_rows = ["".join(parts) for parts in zip(*padded)]

    content = [
        "",
        *art_rows,
        "",
        "Transport and Merge for Time Dilation",
        "Merging 5 models into Qwen3-8B",
        "",
    ]

    lines = ["╔" + "═" * inner_width + "╗"]
    lines.extend("║" + row.center(inner_width) + "║" for row in content)
    lines.append("╚" + "═" * inner_width + "╝")

    print("\n" + "\n".join(lines) + "\n")
|
|
|
|
def main():
    """Run the TD Fuse command-line workflow.

    Parses the CLI arguments, prints the banner, builds a ``MergeConfig``
    and then runs exactly one of these modes, checked in this order:

    1. ``--reinject STAGE``: load the model at ``--model-path``, re-inject
       the saved residuals through ``ResidualBank`` and save the patched
       model and tokenizer under the output directory.
    2. ``--heal-only``: run ``heal_model`` on ``--model-path``.
    3. ``--dry-run``: print the merge plan and return.
    4. Otherwise: run ``run_pipeline`` for the selected stages, optionally
       heal the final checkpoint, write ``pipeline_results.json`` and
       print a summary.

    ``--dry-run`` is honoured in modes 1 and 2 as well: the plan is printed
    and nothing is loaded, trained or written.

    Exits with status 1 when a required ``--model-path`` is missing, and
    after a pipeline run exits 0 only if the reported overall status is
    ``"all_passed"`` (1 otherwise).
    """
    args = parse_args()
    print_banner()

    cfg = MergeConfig(
        output_dir=args.output,
        checkpoint_dir=args.checkpoint_dir,
        tm_repo_path=args.tm_repo,
    )

    # Resolve the --stage choice into the list of stages to merge.
    if args.stage == "demo":
        stages = DEMO_STAGES
    elif args.stage == "all":
        stages = FULL_STAGES
    else:
        stages = [args.stage]

    # ----- Mode 1: residual re-injection ---------------------------------
    if args.reinject:
        if not args.model_path:
            print("Error: --reinject requires --model-path")
            sys.exit(1)

        patched_dir = Path(cfg.output_dir) / f"reinjected_{args.reinject}_{args.strength}"

        if args.dry_run:
            # Never load or write a model during a dry run.
            print("\n=== DRY RUN ===")
            print(f"Would re-inject residuals from stage: {args.reinject}")
            print(f"Side: {args.reinject_side}, Strength: {args.strength}")
            print(f"Model: {args.model_path}")
            print(f"Patched model would be saved to: {patched_dir}")
            print("\nNo changes made (dry run).")
            return

        # Imported lazily so the other modes do not pay for these imports.
        from transformers import AutoModelForCausalLM, AutoTokenizer
        import torch

        print(f"\n[run] Re-injecting residuals from stage: {args.reinject}")
        print(f"[run] Side: {args.reinject_side}, Strength: {args.strength}")

        residual_bank = ResidualBank(cfg)
        tokenizer = AutoTokenizer.from_pretrained(args.model_path)
        model = AutoModelForCausalLM.from_pretrained(
            args.model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

        model = residual_bank.reinject_residuals(
            model, args.reinject,
            side=args.reinject_side,
            strength=args.strength,
        )

        patched_dir.mkdir(parents=True, exist_ok=True)
        model.save_pretrained(str(patched_dir))
        tokenizer.save_pretrained(str(patched_dir))
        print(f"\n[run] Patched model saved to: {patched_dir}")
        return

    # ----- Mode 2: heal an existing checkpoint ---------------------------
    if args.heal_only:
        if not args.model_path:
            print("Error: --heal-only requires --model-path")
            sys.exit(1)

        if args.dry_run:
            # Never start a fine-tune during a dry run.
            print("\n=== DRY RUN ===")
            print(f"Would heal model at: {args.model_path}")
            print("\nNo changes made (dry run).")
            return

        print(f"\n[run] Healing model at: {args.model_path}")
        healed_path = heal_model(args.model_path, cfg)
        print(f"\n[run] Healed model saved to: {healed_path}")
        return

    # ----- Mode 3: dry run of the merge pipeline -------------------------
    if args.dry_run:
        print("\n=== DRY RUN ===")
        print(f"Stages: {stages}")
        print(f"Output: {cfg.output_dir}")
        print(f"Checkpoints: {cfg.checkpoint_dir}")
        print(f"T&M repo: {cfg.tm_repo_path}")
        print(f"Heal after: {args.heal}")
        print("\nWould run:")
        for i, stage in enumerate(stages, 1):
            print(f" {i}. Merge {stage} → target")
            print(" → Validate (canary + perplexity + thinking + reasoning)")
            print(" → Checkpoint")
        if args.heal:
            print(f" {len(stages) + 1}. QLoRA healing fine-tune")
        print("\nNo changes made (dry run).")
        return

    # ----- Mode 4: run the merge pipeline --------------------------------
    start_time = time.time()
    results = run_pipeline(stages, cfg)
    elapsed = time.time() - start_time
    print(f"\n[run] Pipeline completed in {elapsed / 60:.1f} minutes")

    # Optional healing pass; skipped when no final checkpoint was reported.
    if args.heal and results.get("final_checkpoint"):
        print("\n[run] Starting healing fine-tune...")
        healed_path = heal_model(results["final_checkpoint"], cfg)
        results["healed_model_path"] = healed_path
        print(f"[run] Healed model: {healed_path}")

    results_path = Path(cfg.output_dir) / "pipeline_results.json"
    # Create the directory here so a finished run cannot lose its results
    # file just because the output directory is missing.
    results_path.parent.mkdir(parents=True, exist_ok=True)

    json_scalars = (int, float, str, bool, type(None))

    def make_serialisable(obj):
        """Recursively convert *obj* into something ``json.dump`` accepts."""
        if isinstance(obj, dict):
            # JSON object keys must be scalars; stringify anything else.
            return {
                (k if isinstance(k, json_scalars) else str(k)): make_serialisable(v)
                for k, v in obj.items()
            }
        if isinstance(obj, (list, tuple)):
            return [make_serialisable(v) for v in obj]
        if isinstance(obj, json_scalars):
            return obj
        # Paths and other non-JSON objects are stored as their string form.
        return str(obj)

    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(make_serialisable(results), f, indent=2)
    print(f"[run] Results saved to {results_path}")

    # A missing status must not crash the summary; it counts as a failure.
    status = results.get("overall_status", "unknown")

    print(f"\n{'=' * 60}")
    print("TD FUSE COMPLETE")
    print(f"{'=' * 60}")
    print(f" Status: {status}")
    print(f" Time: {elapsed / 60:.1f} minutes")
    if results.get("final_model_path"):
        print(f" Model: {results['final_model_path']}")
    if results.get("healed_model_path"):
        print(f" Healed: {results['healed_model_path']}")
    print(f" Results: {results_path}")
    print(f"{'=' * 60}")

    # The process exit code mirrors the pipeline verdict.
    sys.exit(0 if status == "all_passed" else 1)
|
|
|
|
# Run the CLI only when this module is executed directly (for example via
# ``python -m td_fuse.run``), not when it is imported.
if __name__ == "__main__":
    main()
|
|