import argparse
import os
import subprocess
import sys
import time

import torch

from bit_transformer.utils import load_model
from bit_transformer.hf_checkpoint import (
    hf_login,
    save_checkpoint,
    download_checkpoint,
)
from bit_transformer import diffusion_inference
from integration_schedule import integration_schedule


def _launch_dashboard() -> list[subprocess.Popen]:
    """Start the MCP server and dashboard processes.

    The server (``mcp_server.py``) is started first. After a fixed 2-second
    pause, the dashboard (``python -m bit_transformer.dashboard_app``) is
    started with ``MCP_SERVER_ADDR`` defaulting to ``http://127.0.0.1:7000``
    unless it is already set in the environment.

    Returns:
        ``[server, dashboard]`` process handles; pass them to
        :func:`_terminate` when done.
    """
    server = subprocess.Popen([sys.executable, "mcp_server.py"])
    try:
        # Crude readiness wait: gives the server a head start before the
        # dashboard tries to reach it.
        time.sleep(2)
        dash_env = dict(os.environ)
        dash_env.setdefault("MCP_SERVER_ADDR", "http://127.0.0.1:7000")
        dashboard = subprocess.Popen(
            [sys.executable, "-m", "bit_transformer.dashboard_app"],
            env=dash_env,
        )
    except BaseException:
        # Don't leak the already-running server if the dashboard fails to
        # launch (or the wait is interrupted, e.g. by Ctrl-C).
        _terminate([server])
        raise
    return [server, dashboard]


def _terminate(procs: list[subprocess.Popen], timeout: float = 5) -> None:
    """Stop each process, escalating from terminate to kill.

    Args:
        procs: Process handles to stop; an empty list is a no-op.
        timeout: Seconds to wait for each process to exit after
            ``terminate()`` before falling back to ``kill()``.
    """
    for p in procs:
        p.terminate()
        try:
            p.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            p.kill()
            # Reap the killed child so it does not linger as a zombie.
            p.wait()


def run_workflow(
    steps: int = 10,
    max_len: int = 64,
    dataset_size: int = 128,
    *,
    launch_ui: bool = False,
    weights_path: str = "weights/model.pt.gz",
    collapsed_path: str = "weights/collapsed.pt.gz",
    plateau_steps: int = 0,
    epochs_per_step: int = 2,
    extra_steps: int = 3,
    collapse: bool = True,
    hf_repo: str | None = None,
    hf_token: str | None = None,
    diffusion: bool = False,
    noise_schedule: str = "linear",
    diffusion_steps: int = 8,
    diffusion_curriculum: bool = False,
    use_checkpoint: bool = True,
    reversible: bool = True,
    qat: bool = False,
) -> tuple:
    """Run the full integration schedule with optional dashboard.

    If ``qat`` is ``True`` the model undergoes 4-bit quantization-aware
    training before being converted to quantized weights for safety checks.

    Steps performed:

    1. If ``launch_ui``, start the MCP server and dashboard subprocesses.
    2. If ``hf_repo`` is given, log in to Hugging Face with ``hf_token`` and
       download the checkpoint to ``weights_path`` if that file is missing.
    3. Call ``integration_schedule`` with the training options below.
    4. Reload the model from ``weights_path`` and print the results.
    5. If ``diffusion``, run ``diffusion_inference`` (``length=max_len``,
       ``steps=diffusion_steps``, ``schedule=noise_schedule``) and print
       the first sample's bits.
    6. If ``hf_repo`` is given, upload the model with ``save_checkpoint``.

    Any dashboard processes are always terminated afterwards, including
    when one of the steps above raises.

    Args:
        steps, max_len, dataset_size, weights_path, plateau_steps,
        collapsed_path, epochs_per_step, extra_steps, collapse, diffusion,
        noise_schedule, diffusion_steps, diffusion_curriculum,
        use_checkpoint, reversible, qat: Forwarded unchanged to
            ``integration_schedule``.
        launch_ui: Start the MCP server and dashboard for the run.
        hf_repo: Hugging Face repository used to fetch and store checkpoints.
        hf_token: Token passed to ``hf_login`` when ``hf_repo`` is set.

    Returns:
        ``(model, collapsed)``: the model loaded from ``weights_path`` and
        the second value returned by ``integration_schedule``.
    """
    procs: list[subprocess.Popen] = []
    if launch_ui:
        procs = _launch_dashboard()
    # Everything after the UI launch lives inside try/finally so the
    # subprocesses are cleaned up even if the Hugging Face setup fails.
    try:
        if hf_repo:
            hf_login(token=hf_token)
            if not os.path.exists(weights_path):
                download_checkpoint(weights_path, repo_id=hf_repo)
        results, collapsed = integration_schedule(
            steps=steps,
            max_len=max_len,
            dataset_size=dataset_size,
            weights_path=weights_path,
            plateau_steps=plateau_steps,
            collapsed_path=collapsed_path,
            epochs_per_step=epochs_per_step,
            extra_steps=extra_steps,
            collapse=collapse,
            diffusion=diffusion,
            noise_schedule=noise_schedule,
            diffusion_steps=diffusion_steps,
            diffusion_curriculum=diffusion_curriculum,
            use_checkpoint=use_checkpoint,
            reversible=reversible,
            qat=qat,
        )
        model = load_model(weights_path)
        print("Workflow results:", results)
        if diffusion:
            sample = diffusion_inference(
                model, length=max_len, steps=diffusion_steps, schedule=noise_schedule
            )
            print("Diffusion inference output bits:", sample[0].tolist())
        if hf_repo:
            save_checkpoint(model, repo_id=hf_repo)
    finally:
        # No-op when the UI was not launched (procs is empty).
        _terminate(procs)
    return model, collapsed


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Unified end-to-end workflow for BitTransformerLM")
    parser.add_argument("--steps", type=int, default=10, help="number of scale-up steps")
    parser.add_argument("--max-len", type=int, default=64, help="sequence length")
    parser.add_argument("--dataset-size", type=int, default=128, help="training dataset size")
    parser.add_argument("--dashboard", action="store_true", help="launch MCP server and dashboard")
    parser.add_argument("--plateau-steps", type=int, default=0, help="extra training steps at final size")
    parser.add_argument("--weights-path", type=str, default="weights/model.pt.gz", help="model weights file")
    parser.add_argument("--collapsed-path", type=str, default="weights/collapsed.pt.gz", help="collapsed model file")
    parser.add_argument("--epochs-per-step", type=int, default=2, help="epochs per training step")
    parser.add_argument("--extra-steps", type=int, default=3, help="optimizer updates after each epoch")
    parser.add_argument("--no-collapse", action="store_true", help="skip collapsed model generation")
    parser.add_argument("--hf-repo", type=str, help="Hugging Face repository for checkpoints")
    parser.add_argument("--hf-token", type=str, default=None, help="Authentication token for Hugging Face")
    parser.add_argument(
        "--diffusion",
        action="store_true",
        help="enable Diffusion LM (non-causal) mode",
    )
    parser.add_argument(
        "--noise-schedule",
        type=str,
        default="linear",
        choices=["linear", "cosine", "exp"],
        help="noise schedule for diffusion mode",
    )
    parser.add_argument(
        "--diffusion-steps",
        type=int,
        default=8,
        help="number of denoising steps for diffusion mode",
    )
    parser.add_argument(
        "--diffusion-curriculum",
        action="store_true",
        help="linearly decay noise over diffusion training epochs",
    )
    parser.add_argument(
        "--no-checkpoint",
        action="store_true",
        help="disable gradient checkpointing for faster but memory-heavy runs",
    )
    parser.add_argument(
        "--no-reversible",
        action="store_true",
        help="use standard transformer blocks instead of reversible layers",
    )
    parser.add_argument(
        "--qat",
        action="store_true",
        help="enable 4-bit quantization-aware training",
    )
    args = parser.parse_args()
    # The --no-* flags are inverted into the positive keyword arguments.
    run_workflow(
        args.steps,
        args.max_len,
        args.dataset_size,
        launch_ui=args.dashboard,
        weights_path=args.weights_path,
        collapsed_path=args.collapsed_path,
        plateau_steps=args.plateau_steps,
        epochs_per_step=args.epochs_per_step,
        extra_steps=args.extra_steps,
        collapse=not args.no_collapse,
        hf_repo=args.hf_repo,
        hf_token=args.hf_token,
        diffusion=args.diffusion,
        noise_schedule=args.noise_schedule,
        diffusion_steps=args.diffusion_steps,
        diffusion_curriculum=args.diffusion_curriculum,
        use_checkpoint=not args.no_checkpoint,
        reversible=not args.no_reversible,
        qat=args.qat,
    )