"""Unified end-to-end workflow for BitTransformerLM (``unified_workflow.py``).

Provenance: uploaded by WCNegentropy ("Upload 65 files", commit 12e8f96,
verified; 6.47 kB).
"""
import argparse
import os
import subprocess
import sys
import time
import torch
from bit_transformer.utils import load_model
from bit_transformer.hf_checkpoint import (
hf_login,
save_checkpoint,
download_checkpoint,
)
from bit_transformer import diffusion_inference
from integration_schedule import integration_schedule
def _launch_dashboard() -> list[subprocess.Popen]:
    """Spawn the MCP server and the dashboard UI as child processes.

    Returns the two ``Popen`` handles (server first) so the caller can
    later shut them down with :func:`_terminate`.
    """
    mcp_proc = subprocess.Popen([sys.executable, "mcp_server.py"])
    # Give the MCP server a moment to start before the dashboard connects.
    time.sleep(2)
    env = dict(os.environ)
    env.setdefault("MCP_SERVER_ADDR", "http://127.0.0.1:7000")
    ui_proc = subprocess.Popen(
        [sys.executable, "-m", "bit_transformer.dashboard_app"],
        env=env,
    )
    return [mcp_proc, ui_proc]
def _terminate(procs: list[subprocess.Popen]) -> None:
for p in procs:
p.terminate()
try:
p.wait(timeout=5)
except Exception:
p.kill()
def run_workflow(
    steps: int = 10,
    max_len: int = 64,
    dataset_size: int = 128,
    *,
    launch_ui: bool = False,
    weights_path: str = "weights/model.pt.gz",
    collapsed_path: str = "weights/collapsed.pt.gz",
    plateau_steps: int = 0,
    epochs_per_step: int = 2,
    extra_steps: int = 3,
    collapse: bool = True,
    hf_repo: str | None = None,
    hf_token: str | None = None,
    diffusion: bool = False,
    noise_schedule: str = "linear",
    diffusion_steps: int = 8,
    diffusion_curriculum: bool = False,
    use_checkpoint: bool = True,
    reversible: bool = True,
    qat: bool = False,
) -> tuple:
    """Run the full integration schedule with optional dashboard.

    If ``qat`` is ``True`` the model undergoes 4-bit quantization-aware training
    before being converted to quantized weights for safety checks.

    Returns a ``(model, collapsed)`` tuple: the trained model loaded from
    ``weights_path`` and the collapsed result from ``integration_schedule``.
    """
    # Optionally bring up the MCP server + dashboard before training starts.
    procs: list[subprocess.Popen] = _launch_dashboard() if launch_ui else []
    if hf_repo:
        hf_login(token=hf_token)
        # Pull an existing checkpoint from the Hub when none exists locally.
        if not os.path.exists(weights_path):
            download_checkpoint(weights_path, repo_id=hf_repo)
    try:
        results, collapsed = integration_schedule(
            steps=steps,
            max_len=max_len,
            dataset_size=dataset_size,
            weights_path=weights_path,
            plateau_steps=plateau_steps,
            collapsed_path=collapsed_path,
            epochs_per_step=epochs_per_step,
            extra_steps=extra_steps,
            collapse=collapse,
            diffusion=diffusion,
            noise_schedule=noise_schedule,
            diffusion_steps=diffusion_steps,
            diffusion_curriculum=diffusion_curriculum,
            use_checkpoint=use_checkpoint,
            reversible=reversible,
            qat=qat,
        )
        model = load_model(weights_path)
        print("Workflow results:", results)
        if diffusion:
            # Demonstrate non-causal sampling with the freshly trained model.
            sample = diffusion_inference(
                model, length=max_len, steps=diffusion_steps, schedule=noise_schedule
            )
            print("Diffusion inference output bits:", sample[0].tolist())
        if hf_repo:
            save_checkpoint(model, repo_id=hf_repo)
    finally:
        # Dashboard processes exist only when launch_ui was requested.
        if procs:
            _terminate(procs)
    return model, collapsed
if __name__ == "__main__":
    # Command-line front end: every flag maps 1:1 onto a run_workflow() argument.
    parser = argparse.ArgumentParser(description="Unified end-to-end workflow for BitTransformerLM")
    # Core training-size knobs.
    parser.add_argument("--steps", type=int, default=10, help="number of scale-up steps")
    parser.add_argument("--max-len", type=int, default=64, help="sequence length")
    parser.add_argument("--dataset-size", type=int, default=128, help="training dataset size")
    parser.add_argument("--dashboard", action="store_true", help="launch MCP server and dashboard")
    parser.add_argument("--plateau-steps", type=int, default=0, help="extra training steps at final size")
    # Checkpoint file locations (local paths; optionally synced with the HF Hub below).
    parser.add_argument("--weights-path", type=str, default="weights/model.pt.gz", help="model weights file")
    parser.add_argument("--collapsed-path", type=str, default="weights/collapsed.pt.gz", help="collapsed model file")
    parser.add_argument("--epochs-per-step", type=int, default=2, help="epochs per training step")
    parser.add_argument("--extra-steps", type=int, default=3, help="optimizer updates after each epoch")
    parser.add_argument("--no-collapse", action="store_true", help="skip collapsed model generation")
    # Hugging Face Hub checkpoint sync (both optional).
    parser.add_argument("--hf-repo", type=str, help="Hugging Face repository for checkpoints")
    parser.add_argument("--hf-token", type=str, default=None, help="Authentication token for Hugging Face")
    # Diffusion-mode options (only meaningful together with --diffusion).
    parser.add_argument(
        "--diffusion",
        action="store_true",
        help="enable Diffusion LM (non-causal) mode",
    )
    parser.add_argument(
        "--noise-schedule",
        type=str,
        default="linear",
        choices=["linear", "cosine", "exp"],
        help="noise schedule for diffusion mode",
    )
    parser.add_argument(
        "--diffusion-steps",
        type=int,
        default=8,
        help="number of denoising steps for diffusion mode",
    )
    parser.add_argument(
        "--diffusion-curriculum",
        action="store_true",
        help="linearly decay noise over diffusion training epochs",
    )
    # Memory/compute trade-offs: checkpointing and reversible layers default ON,
    # so the flags below are expressed as opt-outs.
    parser.add_argument(
        "--no-checkpoint",
        action="store_true",
        help="disable gradient checkpointing for faster but memory-heavy runs",
    )
    parser.add_argument(
        "--no-reversible",
        action="store_true",
        help="use standard transformer blocks instead of reversible layers",
    )
    parser.add_argument(
        "--qat",
        action="store_true",
        help="enable 4-bit quantization-aware training",
    )
    args = parser.parse_args()
    # Negative flags are inverted here so run_workflow sees positive booleans.
    run_workflow(
        args.steps,
        args.max_len,
        args.dataset_size,
        launch_ui=args.dashboard,
        weights_path=args.weights_path,
        collapsed_path=args.collapsed_path,
        plateau_steps=args.plateau_steps,
        epochs_per_step=args.epochs_per_step,
        extra_steps=args.extra_steps,
        collapse=not args.no_collapse,
        hf_repo=args.hf_repo,
        hf_token=args.hf_token,
        diffusion=args.diffusion,
        noise_schedule=args.noise_schedule,
        diffusion_steps=args.diffusion_steps,
        diffusion_curriculum=args.diffusion_curriculum,
        use_checkpoint=not args.no_checkpoint,
        reversible=not args.no_reversible,
        qat=args.qat,
    )