WCNegentropy committed on
Commit
e83fc25
·
verified ·
1 Parent(s): 3e0a2e7

Remove nested directory: BitTransformerLM/unified_workflow.py

Browse files
Files changed (1) hide show
  1. BitTransformerLM/unified_workflow.py +0 -182
BitTransformerLM/unified_workflow.py DELETED
@@ -1,182 +0,0 @@
1
- import argparse
2
- import os
3
- import subprocess
4
- import sys
5
- import time
6
- import torch
7
- from bit_transformer.utils import load_model
8
- from bit_transformer.hf_checkpoint import (
9
- hf_login,
10
- save_checkpoint,
11
- download_checkpoint,
12
- )
13
- from bit_transformer import diffusion_inference
14
-
15
- from integration_schedule import integration_schedule
16
-
17
-
18
def _launch_dashboard() -> list[subprocess.Popen]:
    """Spawn the MCP server and the dashboard app as child processes.

    The server is started first and given a short head start so the
    dashboard can reach it; the dashboard inherits the environment with
    ``MCP_SERVER_ADDR`` defaulted to the local server address.

    Returns:
        The two ``Popen`` handles, server first, for later termination.
    """
    mcp_server = subprocess.Popen([sys.executable, "mcp_server.py"])
    # Give the server a moment to bind its port before the UI connects.
    time.sleep(2)
    env = os.environ.copy()
    env.setdefault("MCP_SERVER_ADDR", "http://127.0.0.1:7000")
    ui = subprocess.Popen(
        [sys.executable, "-m", "bit_transformer.dashboard_app"],
        env=env,
    )
    return [mcp_server, ui]
29
-
30
-
31
- def _terminate(procs: list[subprocess.Popen]) -> None:
32
- for p in procs:
33
- p.terminate()
34
- try:
35
- p.wait(timeout=5)
36
- except Exception:
37
- p.kill()
38
-
39
-
40
def run_workflow(
    steps: int = 10,
    max_len: int = 64,
    dataset_size: int = 128,
    *,
    launch_ui: bool = False,
    weights_path: str = "weights/model.pt.gz",
    collapsed_path: str = "weights/collapsed.pt.gz",
    plateau_steps: int = 0,
    epochs_per_step: int = 2,
    extra_steps: int = 3,
    collapse: bool = True,
    hf_repo: str | None = None,
    hf_token: str | None = None,
    diffusion: bool = False,
    noise_schedule: str = "linear",
    diffusion_steps: int = 8,
    diffusion_curriculum: bool = False,
    use_checkpoint: bool = True,
    reversible: bool = True,
    qat: bool = False,
) -> tuple:
    """Run the full integration schedule with optional dashboard.

    Optionally launches the MCP server/dashboard UI for the duration of
    the run, syncs checkpoints with a Hugging Face repo when ``hf_repo``
    is given, and demonstrates diffusion inference when ``diffusion`` is
    enabled.

    If ``qat`` is ``True`` the model undergoes 4-bit quantization-aware training
    before being converted to quantized weights for safety checks.

    Returns:
        ``(model, collapsed)`` — the trained model and the collapsed
        variant produced by :func:`integration_schedule`.
    """
    ui_procs: list[subprocess.Popen] = []
    if launch_ui:
        ui_procs = _launch_dashboard()

    if hf_repo:
        hf_login(token=hf_token)
        # Pull an existing checkpoint only when none is present locally.
        if not os.path.exists(weights_path):
            download_checkpoint(weights_path, repo_id=hf_repo)

    try:
        results, collapsed = integration_schedule(
            steps=steps,
            max_len=max_len,
            dataset_size=dataset_size,
            weights_path=weights_path,
            plateau_steps=plateau_steps,
            collapsed_path=collapsed_path,
            epochs_per_step=epochs_per_step,
            extra_steps=extra_steps,
            collapse=collapse,
            diffusion=diffusion,
            noise_schedule=noise_schedule,
            diffusion_steps=diffusion_steps,
            diffusion_curriculum=diffusion_curriculum,
            use_checkpoint=use_checkpoint,
            reversible=reversible,
            qat=qat,
        )
        model = load_model(weights_path)
        print("Workflow results:", results)
        if diffusion:
            sample = diffusion_inference(
                model, length=max_len, steps=diffusion_steps, schedule=noise_schedule
            )
            print("Diffusion inference output bits:", sample[0].tolist())
        if hf_repo:
            save_checkpoint(model, repo_id=hf_repo)
    finally:
        # Always tear down the UI processes, even if training raised.
        if launch_ui:
            _terminate(ui_procs)
    return model, collapsed
106
-
107
-
108
- if __name__ == "__main__":
109
- parser = argparse.ArgumentParser(description="Unified end-to-end workflow for BitTransformerLM")
110
- parser.add_argument("--steps", type=int, default=10, help="number of scale-up steps")
111
- parser.add_argument("--max-len", type=int, default=64, help="sequence length")
112
- parser.add_argument("--dataset-size", type=int, default=128, help="training dataset size")
113
- parser.add_argument("--dashboard", action="store_true", help="launch MCP server and dashboard")
114
- parser.add_argument("--plateau-steps", type=int, default=0, help="extra training steps at final size")
115
- parser.add_argument("--weights-path", type=str, default="weights/model.pt.gz", help="model weights file")
116
- parser.add_argument("--collapsed-path", type=str, default="weights/collapsed.pt.gz", help="collapsed model file")
117
- parser.add_argument("--epochs-per-step", type=int, default=2, help="epochs per training step")
118
- parser.add_argument("--extra-steps", type=int, default=3, help="optimizer updates after each epoch")
119
- parser.add_argument("--no-collapse", action="store_true", help="skip collapsed model generation")
120
- parser.add_argument("--hf-repo", type=str, help="Hugging Face repository for checkpoints")
121
- parser.add_argument("--hf-token", type=str, default=None, help="Authentication token for Hugging Face")
122
- parser.add_argument(
123
- "--diffusion",
124
- action="store_true",
125
- help="enable Diffusion LM (non-causal) mode",
126
- )
127
- parser.add_argument(
128
- "--noise-schedule",
129
- type=str,
130
- default="linear",
131
- choices=["linear", "cosine", "exp"],
132
- help="noise schedule for diffusion mode",
133
- )
134
- parser.add_argument(
135
- "--diffusion-steps",
136
- type=int,
137
- default=8,
138
- help="number of denoising steps for diffusion mode",
139
- )
140
- parser.add_argument(
141
- "--diffusion-curriculum",
142
- action="store_true",
143
- help="linearly decay noise over diffusion training epochs",
144
- )
145
- parser.add_argument(
146
- "--no-checkpoint",
147
- action="store_true",
148
- help="disable gradient checkpointing for faster but memory-heavy runs",
149
- )
150
- parser.add_argument(
151
- "--no-reversible",
152
- action="store_true",
153
- help="use standard transformer blocks instead of reversible layers",
154
- )
155
- parser.add_argument(
156
- "--qat",
157
- action="store_true",
158
- help="enable 4-bit quantization-aware training",
159
- )
160
- args = parser.parse_args()
161
-
162
- run_workflow(
163
- args.steps,
164
- args.max_len,
165
- args.dataset_size,
166
- launch_ui=args.dashboard,
167
- weights_path=args.weights_path,
168
- collapsed_path=args.collapsed_path,
169
- plateau_steps=args.plateau_steps,
170
- epochs_per_step=args.epochs_per_step,
171
- extra_steps=args.extra_steps,
172
- collapse=not args.no_collapse,
173
- hf_repo=args.hf_repo,
174
- hf_token=args.hf_token,
175
- diffusion=args.diffusion,
176
- noise_schedule=args.noise_schedule,
177
- diffusion_steps=args.diffusion_steps,
178
- diffusion_curriculum=args.diffusion_curriculum,
179
- use_checkpoint=not args.no_checkpoint,
180
- reversible=not args.no_reversible,
181
- qat=args.qat,
182
- )