""" TD Fuse — Main Entry Point. Usage: # Dad demo: merge just DeepSeek → Qwen3-8B (easiest, lowest risk) python -m td_fuse.run --stage demo # Full pipeline: all 4 merges python -m td_fuse.run --stage all # Single model merge python -m td_fuse.run --stage deepseek python -m td_fuse.run --stage mimo python -m td_fuse.run --stage llama python -m td_fuse.run --stage falcon # With healing fine-tune after merge python -m td_fuse.run --stage demo --heal # Custom output directory python -m td_fuse.run --stage all --output ./my_output # Heal an existing checkpoint python -m td_fuse.run --heal-only --model-path ./td_fuse_checkpoints/after_deepseek Findings: #25 (dad demo plan), #22 (merge order), #24 (official T&M pipeline) """ import argparse import json import sys import time from pathlib import Path from .config import MergeConfig, DEMO_STAGES, FULL_STAGES from .merge import run_pipeline, ResidualBank from .heal import heal_model def parse_args(): parser = argparse.ArgumentParser( description="TD Fuse — Transport and Merge pipeline for Time Dilation", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: python -m td_fuse.run --stage demo # Dad demo (DeepSeek only) python -m td_fuse.run --stage all # Full 4-model merge python -m td_fuse.run --stage all --heal # Merge + healing fine-tune python -m td_fuse.run --heal-only --model-path ./checkpoint python -m td_fuse.run --reinject deepseek --strength 0.2 --model-path ./final """, ) parser.add_argument( "--stage", type=str, default="demo", choices=["demo", "all", "deepseek", "mimo", "llama", "falcon"], help="Which merge stage(s) to run (default: demo)", ) parser.add_argument( "--heal", action="store_true", help="Run healing fine-tune after merge", ) parser.add_argument( "--heal-only", action="store_true", help="Only run healing (skip merge), requires --model-path", ) parser.add_argument( "--model-path", type=str, default=None, help="Path to existing model/checkpoint (for --heal-only)", ) parser.add_argument( "--output", type=str, default="./td_fuse_outputs", help="Output directory (default: ./td_fuse_outputs)", ) parser.add_argument( "--checkpoint-dir", type=str, default="./td_fuse_checkpoints", help="Checkpoint directory (default: ./td_fuse_checkpoints)", ) parser.add_argument( "--tm-repo", type=str, default="./Cross-Architecture-Merging-for-Large-Language-Models", help="Path to official T&M repo", ) parser.add_argument( "--dry-run", action="store_true", help="Print what would happen without actually running", ) parser.add_argument( "--reinject", type=str, default=None, help="Re-inject saved residuals from a stage (e.g., --reinject deepseek)", ) parser.add_argument( "--reinject-side", type=str, default="both", choices=["target", "source", "both"], help="Which side's residuals to re-inject (default: both)", ) parser.add_argument( "--strength", type=float, default=0.2, help="Residual re-injection strength, 0-1 (default: 0.2)", ) return parser.parse_args() def print_banner(): """Print the TD Fuse banner.""" banner = """ ╔══════════════════════════════════════════════════╗ ║ ║ ║ ████████╗██████╗ ███████╗██╗ ██╗███████╗ ║ ║ ╚══██╔══╝██╔══██╗ ██╔════╝██║ ██║██╔════╝ ║ ║ ██║ ██║ ██║ █████╗ ██║ ██║███████╗ ║ ║ ██║ ██║ ██║ ██╔══╝ ██║ ██║╚════██║ ║ ║ ██║ ██████╔╝ ██║ ╚██████╔╝███████║ ║ ║ ╚═╝ ╚═════╝ ╚═╝ ╚═════╝ ╚══════╝ ║ ║ ║ ║ Transport and Merge for Time Dilation ║ ║ Merging 5 models into Qwen3-8B ║ ║ ║ ╚══════════════════════════════════════════════════╝ """ print(banner) def main(): args = parse_args() print_banner() # Build config from args cfg = MergeConfig( output_dir=args.output, checkpoint_dir=args.checkpoint_dir, tm_repo_path=args.tm_repo, ) # Determine which stages to run if args.stage == "demo": stages = DEMO_STAGES elif args.stage == "all": stages = FULL_STAGES else: stages = [args.stage] # --- Reinject residuals mode --- if args.reinject: if not args.model_path: print("Error: --reinject requires --model-path") sys.exit(1) from transformers import AutoModelForCausalLM, AutoTokenizer import torch print(f"\n[run] Re-injecting residuals from stage: {args.reinject}") print(f"[run] Side: {args.reinject_side}, Strength: {args.strength}") residual_bank = ResidualBank(cfg) tokenizer = AutoTokenizer.from_pretrained(args.model_path) model = AutoModelForCausalLM.from_pretrained( args.model_path, torch_dtype=torch.bfloat16, device_map="auto", ) model = residual_bank.reinject_residuals( model, args.reinject, side=args.reinject_side, strength=args.strength, ) # Save the patched model patched_dir = Path(cfg.output_dir) / f"reinjected_{args.reinject}_{args.strength}" patched_dir.mkdir(parents=True, exist_ok=True) model.save_pretrained(str(patched_dir)) tokenizer.save_pretrained(str(patched_dir)) print(f"\n[run] Patched model saved to: {patched_dir}") return # --- Heal-only mode --- if args.heal_only: if not args.model_path: print("Error: --heal-only requires --model-path") sys.exit(1) print(f"\n[run] Healing model at: {args.model_path}") healed_path = heal_model(args.model_path, cfg) print(f"\n[run] Healed model saved to: {healed_path}") return # --- Dry run --- if args.dry_run: print("\n=== DRY RUN ===") print(f"Stages: {stages}") print(f"Output: {cfg.output_dir}") print(f"Checkpoints: {cfg.checkpoint_dir}") print(f"T&M repo: {cfg.tm_repo_path}") print(f"Heal after: {args.heal}") print(f"\nWould run:") for i, stage in enumerate(stages, 1): print(f" {i}. Merge {stage} → target") print(f" → Validate (canary + perplexity + thinking + reasoning)") print(f" → Checkpoint") if args.heal: print(f" {len(stages) + 1}. QLoRA healing fine-tune") print("\nNo changes made (dry run).") return # --- Run the pipeline --- start_time = time.time() results = run_pipeline(stages, cfg) elapsed = time.time() - start_time print(f"\n[run] Pipeline completed in {elapsed / 60:.1f} minutes") # --- Healing fine-tune (optional) --- if args.heal and results.get("final_checkpoint"): print("\n[run] Starting healing fine-tune...") healed_path = heal_model(results["final_checkpoint"], cfg) results["healed_model_path"] = healed_path print(f"[run] Healed model: {healed_path}") # --- Save results --- results_path = Path(cfg.output_dir) / "pipeline_results.json" # Convert non-serialisable objects def make_serialisable(obj): if isinstance(obj, dict): return {k: make_serialisable(v) for k, v in obj.items()} elif isinstance(obj, list): return [make_serialisable(v) for v in obj] elif isinstance(obj, (int, float, str, bool, type(None))): return obj else: return str(obj) with open(results_path, "w") as f: json.dump(make_serialisable(results), f, indent=2) print(f"[run] Results saved to {results_path}") # --- Final summary --- print(f"\n{'=' * 60}") print("TD FUSE COMPLETE") print(f"{'=' * 60}") print(f" Status: {results['overall_status']}") print(f" Time: {elapsed / 60:.1f} minutes") if results.get("final_model_path"): print(f" Model: {results['final_model_path']}") if results.get("healed_model_path"): print(f" Healed: {results['healed_model_path']}") print(f" Results: {results_path}") print(f"{'=' * 60}") # Exit code based on result if results["overall_status"] == "all_passed": sys.exit(0) else: sys.exit(1) if __name__ == "__main__": main()