WCNegentropy committed on
Commit
2f70b79
·
verified ·
1 Parent(s): 00c7a97

🚀 Refined BitTransformerLM: Organized codebase with best practices

Browse files
Files changed (1) hide show
  1. bit_transformer/cli.py +239 -0
bit_transformer/cli.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Command-line interface entry points for BitTransformerLM."""
2
+
3
+ import sys
4
+ import logging
5
+ from pathlib import Path
6
+ from typing import Optional
7
+
8
+ import torch
9
+
10
+ from .cli_standards import create_training_parser, create_inference_parser, BitTransformerCLI
11
+ from .config import (
12
+ ExperimentConfig,
13
+ ModelConfig,
14
+ TrainingConfig,
15
+ SafetyConfig,
16
+ DataConfig,
17
+ get_small_config,
18
+ get_medium_config,
19
+ get_large_config,
20
+ )
21
+ from .model import BitTransformerLM, diffusion_inference
22
+ from .training import train_loop
23
+ from .bit_io import text_to_bits, bits_to_text, infer_text
24
+ from .utils import save_model, load_model
25
+ from .dashboard_app import run_dashboard
26
+
27
+
28
def setup_logging(level: str = "INFO") -> None:
    """Configure root logging to emit formatted records on stdout.

    Args:
        level: Case-insensitive name of a standard ``logging`` level,
            e.g. ``"debug"``, ``"INFO"`` or ``"Warning"``.

    Raises:
        ValueError: If ``level`` does not name a numeric ``logging`` level.

    Note:
        ``logging.basicConfig`` leaves the configuration untouched when the
        root logger already has handlers attached.
    """
    # A bare ``getattr`` raises an opaque AttributeError for unknown names and
    # would also hand back non-level attributes (e.g. ``BASIC_FORMAT``), so
    # resolve defensively and insist on an integer level.
    numeric_level = getattr(logging, level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError(f"Invalid log level: {level!r}")

    logging.basicConfig(
        level=numeric_level,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        handlers=[
            logging.StreamHandler(sys.stdout),
        ],
    )
37
+
38
+
39
def train_cli() -> None:
    """CLI entry point for training BitTransformerLM models.

    Parses the training arguments, builds an experiment configuration
    (optionally from a size preset), constructs the model, trains it on
    randomly generated bit sequences and writes the final weights to
    ``<output_path>/model_final.pt``.

    The process exits with status 1 if preparing the output directory,
    training, or saving raises an exception.
    """
    parser = create_training_parser()
    args = parser.parse_args()

    setup_logging(args.log_level)
    logger = logging.getLogger(__name__)

    # Start from a named preset when one is requested; any other value
    # falls back to the default experiment configuration.
    preset_factories = {
        "small": get_small_config,
        "medium": get_medium_config,
        "large": get_large_config,
    }
    config = preset_factories.get(args.model_size, ExperimentConfig)()

    # Override with command line arguments.
    # NOTE(review): these assignments are unconditional, so whatever defaults
    # the parser supplies replace the preset's values — confirm this is the
    # intended precedence.
    config.model.d_model = args.d_model
    config.model.nhead = args.num_heads
    config.model.num_layers = args.num_layers
    config.model.max_seq_len = args.max_seq_len

    config.training.epochs = args.epochs
    config.training.batch_size = args.batch_size
    config.training.learning_rate = args.learning_rate
    config.training.weight_decay = args.weight_decay
    config.training.gradient_clip_val = args.grad_clip
    config.training.warmup_steps = args.warmup_steps
    config.training.amp = args.use_amp
    config.training.compile_model = args.compile_model

    config.safety.k_threshold = args.min_negentropy
    config.safety.c_threshold = args.max_complexity
    config.safety.s_threshold = args.min_symbiosis
    config.safety.enable_safety = args.enable_safety_gates

    config.data.dataset_path = Path(args.input_path) if args.input_path else None
    config.data.max_sequence_length = args.seq_length
    config.data.num_workers = args.num_workers

    config.output_dir = Path(args.output_path)
    config.seed = args.seed

    # Set device
    config.device = "cuda" if torch.cuda.is_available() else "cpu"

    logger.info(f"Starting training with config: {config.experiment_name}")
    logger.info(f"Model: {config.model.d_model}d, {config.model.num_layers}L, {config.model.nhead}H")
    logger.info(f"Device: {config.device}")

    # Seed before building the model so that any random weight
    # initialisation is reproducible for a given --seed.
    torch.manual_seed(config.seed)

    # Create model
    model = BitTransformerLM(**config.model.to_dict())
    model = model.to(config.device)

    # Create synthetic dataset for demonstration. ``config.data.dataset_path``
    # is recorded above but is not read here.
    logger.info("Creating synthetic training data...")
    # Re-seed so the synthetic data depends only on the seed, not on how many
    # random numbers model construction consumed.
    torch.manual_seed(config.seed)
    data = torch.randint(0, 2, (args.dataset_size, config.data.max_sequence_length))

    # Train model
    logger.info("Starting training...")
    try:
        # Create the output directory up front: a bad path is reported before
        # the training run instead of when the final save is attempted.
        config.output_dir.mkdir(parents=True, exist_ok=True)

        train_loop(
            model,
            data,
            epochs=config.training.epochs,
            batch_size=config.training.batch_size,
            amp=config.training.amp,
            compile_model=config.training.compile_model,
            log=True,
        )

        # Save model
        save_path = config.output_dir / "model_final.pt"
        save_model(model, save_path)
        logger.info(f"Model saved to {save_path}")

    except Exception as e:
        # ``exception`` keeps the traceback, which ``error`` would drop.
        logger.exception(f"Training failed: {e}")
        sys.exit(1)
124
+
125
+
126
def infer_cli() -> None:
    """CLI entry point for BitTransformerLM inference.

    Extends the standard inference parser with ``--prompt``, ``--max-tokens``,
    ``--temperature`` and ``--use-diffusion``, loads the model from
    ``args.weights_path``, generates text for the prompt and prints it.

    Generation modes:
        * ``--use-diffusion``: calls ``diffusion_inference`` and decodes the
          first returned sequence with ``bits_to_text``.
        * safety gates enabled: calls ``infer_text`` with the complexity and
          symbiosis thresholds.
        * otherwise: calls ``sample_text`` with ``--max-tokens`` and
          ``--temperature``.

    The process exits with status 1 if the weights file does not exist or if
    generation raises an exception.
    """
    parser = create_inference_parser()
    parser.add_argument("--prompt", type=str, required=True, help="Text prompt for generation")
    parser.add_argument("--max-tokens", type=int, default=50, help="Maximum tokens to generate")
    parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature")
    parser.add_argument("--use-diffusion", action="store_true", help="Use diffusion mode")
    args = parser.parse_args()

    setup_logging(args.log_level)
    logger = logging.getLogger(__name__)

    # Load model
    if not Path(args.weights_path).exists():
        logger.error(f"Model weights not found at {args.weights_path}")
        sys.exit(1)

    logger.info(f"Loading model from {args.weights_path}")
    model = load_model(args.weights_path)
    model.eval()

    # Set device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    logger.info(f"Model loaded on {device}")
    logger.info(f"Prompt: {args.prompt}")

    try:
        # Inference never needs gradients; disabling autograd avoids building
        # a graph and keeps memory use down.
        with torch.no_grad():
            if args.use_diffusion:
                # Diffusion inference
                logger.info("Using diffusion inference mode")
                # The prompt only contributes to the requested length here; it
                # is not passed to ``diffusion_inference``.
                prompt_bits = text_to_bits(args.prompt)
                length = len(prompt_bits) + args.max_tokens * 9  # Approximate

                generated_bits = diffusion_inference(
                    model,
                    length=length,
                    steps=args.diffusion_steps,
                    schedule=args.noise_schedule,
                )

                result = bits_to_text(generated_bits[0].tolist())

            elif args.enable_safety_gates:
                # Standard autoregressive inference with safety.
                # NOTE(review): --max-tokens and --temperature are not
                # forwarded on this path — confirm that is intended.
                result = infer_text(
                    model,
                    args.prompt,
                    c_floor=args.max_complexity,
                    s_floor=args.min_symbiosis,
                )
            else:
                # Simple generation without safety gates
                from .bit_io import sample_text

                result = sample_text(
                    model,
                    args.prompt,
                    max_new_tokens=args.max_tokens,
                    temperature=args.temperature,
                )

        print(f"\nGenerated text:\n{result}")

    except Exception as e:
        # ``exception`` keeps the traceback, which ``error`` would drop.
        logger.exception(f"Inference failed: {e}")
        sys.exit(1)
194
+
195
+
196
def dashboard_cli() -> None:
    """CLI entry point for BitTransformerLM dashboard.

    Builds the standard parser with the ``"io"`` argument group, adds
    ``--host``, ``--port`` and ``--share``, and hands those values to
    ``run_dashboard``.

    The process exits with status 1 if ``run_dashboard`` raises an exception.
    """
    parser = BitTransformerCLI.create_standard_parser(
        "BitTransformerLM Dashboard",
        ["io"],
    )
    parser.add_argument("--host", type=str, default="127.0.0.1", help="Dashboard host")
    parser.add_argument("--port", type=int, default=7860, help="Dashboard port")
    parser.add_argument("--share", action="store_true", help="Create public link")
    args = parser.parse_args()

    setup_logging(args.log_level)
    logger = logging.getLogger(__name__)

    logger.info(f"Starting BitTransformerLM dashboard on {args.host}:{args.port}")

    try:
        run_dashboard(
            host=args.host,
            port=args.port,
            share=args.share,
        )
    except Exception as e:
        # ``exception`` keeps the traceback, which ``error`` would drop and
        # which is usually needed to diagnose a failed server start.
        logger.exception(f"Dashboard failed to start: {e}")
        sys.exit(1)
221
+
222
+
223
if __name__ == "__main__":
    # Choose the entry point from the name this module was invoked under.
    import os

    invoked_as = os.path.basename(sys.argv[0])

    # Checked in order; the first keyword contained in the name wins.
    entry_points = (
        ("train", train_cli),
        ("infer", infer_cli),
        ("dashboard", dashboard_cli),
    )
    for keyword, entry_point in entry_points:
        if keyword in invoked_as:
            entry_point()
            break
    else:
        # No keyword matched: list the available commands and signal failure.
        print("Available commands:")
        print(" bit-transformer-train - Train a BitTransformerLM model")
        print(" bit-transformer-infer - Run inference with a trained model")
        print(" bit-transformer-dashboard - Launch interactive dashboard")
        sys.exit(1)