# td-toolkit / td_fuse/run.py
# td-builder's picture
# Fixed code: vocab mismatch fix for cross-arch merging (Llama/Falcon)
# 5d61448 verified
"""
TD Fuse — Main Entry Point.
Usage:
# Dad demo: merge just DeepSeek → Qwen3-8B (easiest, lowest risk)
python -m td_fuse.run --stage demo
# Full pipeline: all 4 merges
python -m td_fuse.run --stage all
# Single model merge
python -m td_fuse.run --stage deepseek
python -m td_fuse.run --stage mimo
python -m td_fuse.run --stage llama
python -m td_fuse.run --stage falcon
# With healing fine-tune after merge
python -m td_fuse.run --stage demo --heal
# Custom output directory
python -m td_fuse.run --stage all --output ./my_output
# Heal an existing checkpoint
python -m td_fuse.run --heal-only --model-path ./td_fuse_checkpoints/after_deepseek
Findings: #25 (dad demo plan), #22 (merge order), #24 (official T&M pipeline)
"""
import argparse
import json
import sys
import time
from pathlib import Path
from .config import MergeConfig, DEMO_STAGES, FULL_STAGES
from .merge import run_pipeline, ResidualBank
from .heal import heal_model
def _unit_interval(value):
    """Parse an argparse value as a float in the closed range [0, 1].

    Used as the ``type=`` converter for ``--strength``, whose help text
    promises a 0-1 value. Rejecting bad input here gives a clean usage error
    instead of passing the value through to the re-injection step.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a float in [0, 1].
    """
    try:
        strength = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid float value: {value!r}") from None
    # Negated chained comparison so NaN (which compares False to everything)
    # is rejected as well.
    if not 0.0 <= strength <= 1.0:
        raise argparse.ArgumentTypeError(f"must be between 0 and 1, got {value!r}")
    return strength


def parse_args(argv=None):
    """Parse the TD Fuse command line.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) parses ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``stage``, ``heal``, ``heal_only``,
        ``model_path``, ``output``, ``checkpoint_dir``, ``tm_repo``,
        ``dry_run``, ``reinject``, ``reinject_side`` and ``strength``.
        Invalid arguments make argparse print usage and exit with status 2.
    """
    parser = argparse.ArgumentParser(
        description="TD Fuse — Transport and Merge pipeline for Time Dilation",
        # Raw formatter keeps the hand-laid-out examples in the epilog intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python -m td_fuse.run --stage demo # Dad demo (DeepSeek only)
python -m td_fuse.run --stage all # Full 4-model merge
python -m td_fuse.run --stage all --heal # Merge + healing fine-tune
python -m td_fuse.run --heal-only --model-path ./checkpoint
python -m td_fuse.run --reinject deepseek --strength 0.2 --model-path ./final
""",
    )
    parser.add_argument(
        "--stage",
        type=str,
        default="demo",
        choices=["demo", "all", "deepseek", "mimo", "llama", "falcon"],
        help="Which merge stage(s) to run (default: demo)",
    )
    parser.add_argument(
        "--heal",
        action="store_true",
        help="Run healing fine-tune after merge",
    )
    parser.add_argument(
        "--heal-only",
        action="store_true",
        help="Only run healing (skip merge), requires --model-path",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default=None,
        help="Path to existing model/checkpoint (required for --heal-only and --reinject)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="./td_fuse_outputs",
        help="Output directory (default: ./td_fuse_outputs)",
    )
    parser.add_argument(
        "--checkpoint-dir",
        type=str,
        default="./td_fuse_checkpoints",
        help="Checkpoint directory (default: ./td_fuse_checkpoints)",
    )
    parser.add_argument(
        "--tm-repo",
        type=str,
        default="./Cross-Architecture-Merging-for-Large-Language-Models",
        help="Path to official T&M repo",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print what would happen without actually running",
    )
    parser.add_argument(
        "--reinject",
        type=str,
        default=None,
        help="Re-inject saved residuals from a stage (e.g., --reinject deepseek)",
    )
    parser.add_argument(
        "--reinject-side",
        type=str,
        default="both",
        choices=["target", "source", "both"],
        help="Which side's residuals to re-inject (default: both)",
    )
    parser.add_argument(
        "--strength",
        type=_unit_interval,
        default=0.2,
        help="Residual re-injection strength, 0-1 (default: 0.2)",
    )
    return parser.parse_args(argv)
def print_banner():
"""Print the TD Fuse banner."""
banner = """
╔══════════════════════════════════════════════════╗
║ ║
║ ████████╗██████╗ ███████╗██╗ ██╗███████╗ ║
║ ╚══██╔══╝██╔══██╗ ██╔════╝██║ ██║██╔════╝ ║
║ ██║ ██║ ██║ █████╗ ██║ ██║███████╗ ║
║ ██║ ██║ ██║ ██╔══╝ ██║ ██║╚════██║ ║
║ ██║ ██████╔╝ ██║ ╚██████╔╝███████║ ║
║ ╚═╝ ╚═════╝ ╚═╝ ╚═════╝ ╚══════╝ ║
║ ║
║ Transport and Merge for Time Dilation ║
║ Merging 5 models into Qwen3-8B ║
║ ║
╚══════════════════════════════════════════════════╝
"""
print(banner)
# Value types json.dump can write natively (also valid as dict keys).
_JSON_SCALARS = (int, float, str, bool, type(None))


def _make_serialisable(obj):
    """Recursively convert *obj* into data that ``json.dump`` can write.

    Dicts are walked with their keys stringified unless already JSON scalars,
    because ``json.dump`` raises TypeError on e.g. tuple keys. Lists and
    tuples become JSON arrays. JSON scalars pass through unchanged, and
    anything else (Path objects, tensors, ...) becomes its ``str()``.
    """
    if isinstance(obj, dict):
        return {
            (key if isinstance(key, _JSON_SCALARS) else str(key)): _make_serialisable(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [_make_serialisable(item) for item in obj]
    if isinstance(obj, _JSON_SCALARS):
        return obj
    return str(obj)


def _fail(message):
    """Print ``Error: <message>`` to stderr and exit with status 1."""
    print(f"Error: {message}", file=sys.stderr)
    sys.exit(1)


def _select_stages(stage):
    """Map the ``--stage`` choice to the ordered list of merge stages to run."""
    if stage == "demo":
        return DEMO_STAGES
    if stage == "all":
        return FULL_STAGES
    return [stage]


def _reinject_dir(cfg, args):
    """Return the directory a ``--reinject`` run saves its patched model to."""
    return Path(cfg.output_dir) / f"reinjected_{args.reinject}_{args.strength}"


def _run_reinject(args, cfg):
    """Re-inject saved residuals into ``args.model_path`` and save a patched copy.

    Loads tokenizer and model (bfloat16, ``device_map="auto"``) from
    ``args.model_path``, and applies ``ResidualBank.reinject_residuals`` for
    stage ``args.reinject`` with the requested side and strength. Model and
    tokenizer are written to ``_reinject_dir(cfg, args)``.

    Returns:
        Path of the directory the patched model was saved to.
    """
    # Imported only when this mode actually runs.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    import torch

    print(f"\n[run] Re-injecting residuals from stage: {args.reinject}")
    print(f"[run] Side: {args.reinject_side}, Strength: {args.strength}")
    residual_bank = ResidualBank(cfg)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    model = residual_bank.reinject_residuals(
        model,
        args.reinject,
        side=args.reinject_side,
        strength=args.strength,
    )
    patched_dir = _reinject_dir(cfg, args)
    patched_dir.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(str(patched_dir))
    tokenizer.save_pretrained(str(patched_dir))
    print(f"\n[run] Patched model saved to: {patched_dir}")
    return patched_dir


def _print_dry_run(args, cfg, stages):
    """Describe what the selected mode would do, without loading or writing anything."""
    print("\n=== DRY RUN ===")
    if args.reinject:
        print(f"Mode: re-inject residuals from stage '{args.reinject}'")
        print(f"Model: {args.model_path}")
        print(f"Side: {args.reinject_side}, Strength: {args.strength}")
        print(f"Would save to: {_reinject_dir(cfg, args)}")
    elif args.heal_only:
        print("Mode: heal only (no merge)")
        print(f"Model: {args.model_path}")
    else:
        print(f"Stages: {stages}")
        print(f"Output: {cfg.output_dir}")
        print(f"Checkpoints: {cfg.checkpoint_dir}")
        print(f"T&M repo: {cfg.tm_repo_path}")
        print(f"Heal after: {args.heal}")
        print("\nWould run:")
        for i, stage in enumerate(stages, 1):
            print(f" {i}. Merge {stage} → target")
            print(" → Validate (canary + perplexity + thinking + reasoning)")
            print(" → Checkpoint")
        if args.heal:
            print(f" {len(stages) + 1}. QLoRA healing fine-tune")
    print("\nNo changes made (dry run).")


def _save_results(results, output_dir):
    """Write *results* as indented JSON to ``<output_dir>/pipeline_results.json``.

    Creates ``output_dir`` first so a run whose pipeline never wrote there
    does not crash at the very end and lose its results.

    Returns:
        Path of the written file.
    """
    results_path = Path(output_dir) / "pipeline_results.json"
    results_path.parent.mkdir(parents=True, exist_ok=True)
    with results_path.open("w", encoding="utf-8") as f:
        json.dump(_make_serialisable(results), f, indent=2)
    print(f"[run] Results saved to {results_path}")
    return results_path


def _print_summary(results, status, elapsed, results_path):
    """Print the end-of-run summary (status, duration, model paths, results file)."""
    rule = "=" * 60
    print(f"\n{rule}")
    print("TD FUSE COMPLETE")
    print(rule)
    print(f" Status: {status}")
    print(f" Time: {elapsed / 60:.1f} minutes")
    if results.get("final_model_path"):
        print(f" Model: {results['final_model_path']}")
    if results.get("healed_model_path"):
        print(f" Healed: {results['healed_model_path']}")
    print(f" Results: {results_path}")
    print(rule)


def main():
    """CLI entry point: parse arguments and dispatch to the selected mode.

    Modes (first match wins):
      1. ``--reinject STAGE``: re-inject saved residuals into ``--model-path``.
      2. ``--heal-only``: run the healing fine-tune on ``--model-path``.
      3. otherwise: run the merge pipeline for ``--stage``, then optionally
         heal (``--heal``), save ``pipeline_results.json`` and print a summary.

    ``--dry-run`` is honoured by every mode. It only prints the plan and
    never loads a model or writes files.

    Exits with status 1 on missing required arguments, or when the pipeline's
    ``overall_status`` is not ``"all_passed"``. Exits with status 0 when it is.
    The reinject, heal-only and dry-run modes return normally.
    """
    args = parse_args()
    print_banner()

    cfg = MergeConfig(
        output_dir=args.output,
        checkpoint_dir=args.checkpoint_dir,
        tm_repo_path=args.tm_repo,
    )
    stages = _select_stages(args.stage)

    # Validate before the dry-run exit so --dry-run also reports bad usage.
    if args.reinject and not args.model_path:
        _fail("--reinject requires --model-path")
    if args.heal_only and not args.model_path:
        _fail("--heal-only requires --model-path")

    if args.dry_run:
        _print_dry_run(args, cfg, stages)
        return

    if args.reinject:
        _run_reinject(args, cfg)
        return

    if args.heal_only:
        print(f"\n[run] Healing model at: {args.model_path}")
        healed_path = heal_model(args.model_path, cfg)
        print(f"\n[run] Healed model saved to: {healed_path}")
        return

    # monotonic() is immune to system clock changes during a long run.
    start_time = time.monotonic()
    results = run_pipeline(stages, cfg)
    elapsed = time.monotonic() - start_time
    print(f"\n[run] Pipeline completed in {elapsed / 60:.1f} minutes")

    if args.heal:
        final_checkpoint = results.get("final_checkpoint")
        if final_checkpoint:
            print("\n[run] Starting healing fine-tune...")
            healed_path = heal_model(final_checkpoint, cfg)
            results["healed_model_path"] = healed_path
            print(f"[run] Healed model: {healed_path}")
        else:
            print(
                "\n[run] Warning: --heal requested but the pipeline produced no "
                "final checkpoint; skipping healing.",
                file=sys.stderr,
            )

    results_path = _save_results(results, cfg.output_dir)

    # A missing status is treated as failure rather than crashing with KeyError.
    status = results.get("overall_status", "unknown")
    _print_summary(results, status, elapsed, results_path)
    sys.exit(0 if status == "all_passed" else 1)
if __name__ == "__main__":
    # main() either returns None or calls sys.exit() itself, so SystemExit(None)
    # ends the process with status 0, exactly as falling off the end would.
    raise SystemExit(main())