| | """Command-line interface entry points for BitTransformerLM.""" |
| |
|
| | import sys |
| | import logging |
| | from pathlib import Path |
| | from typing import Optional |
| |
|
| | import torch |
| |
|
| | from .cli_standards import create_training_parser, create_inference_parser, BitTransformerCLI |
| | from .config import ( |
| | ExperimentConfig, |
| | ModelConfig, |
| | TrainingConfig, |
| | SafetyConfig, |
| | DataConfig, |
| | get_small_config, |
| | get_medium_config, |
| | get_large_config, |
| | ) |
| | from .model import BitTransformerLM, diffusion_inference |
| | from .training import train_loop |
| | from .bit_io import text_to_bits, bits_to_text, infer_text |
| | from .utils import save_model, load_model |
| | from .dashboard_app import run_dashboard |
| |
|
| |
|
def setup_logging(level: str = "INFO") -> None:
    """Configure root logging to write to stdout.

    Args:
        level: Case-insensitive logging level name (e.g. ``"info"``,
            ``"DEBUG"``); resolved against the :mod:`logging` module.

    Raises:
        AttributeError: If *level* does not name a logging level.
    """
    resolved_level = getattr(logging, level.upper())
    stdout_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        level=resolved_level,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        handlers=[stdout_handler],
    )
| |
|
| |
|
def train_cli() -> None:
    """CLI entry point for training BitTransformerLM models.

    Parses training arguments, builds an ``ExperimentConfig`` from a size
    preset plus CLI overrides, trains on a synthetic random-bit dataset,
    and saves the final weights. Exits with status 1 on any training or
    saving failure.
    """
    parser = create_training_parser()
    args = parser.parse_args()

    setup_logging(args.log_level)
    logger = logging.getLogger(__name__)

    # Start from a named size preset; fall back to the default config.
    presets = {
        "small": get_small_config,
        "medium": get_medium_config,
        "large": get_large_config,
    }
    config = presets.get(args.model_size, ExperimentConfig)()

    # Model architecture overrides from the CLI.
    config.model.d_model = args.d_model
    config.model.nhead = args.num_heads
    config.model.num_layers = args.num_layers
    config.model.max_seq_len = args.max_seq_len

    # Optimization / training-loop overrides.
    config.training.epochs = args.epochs
    config.training.batch_size = args.batch_size
    config.training.learning_rate = args.learning_rate
    config.training.weight_decay = args.weight_decay
    config.training.gradient_clip_val = args.grad_clip
    config.training.warmup_steps = args.warmup_steps
    config.training.amp = args.use_amp
    config.training.compile_model = args.compile_model

    # Safety telemetry thresholds (K/C/S gates).
    config.safety.k_threshold = args.min_negentropy
    config.safety.c_threshold = args.max_complexity
    config.safety.s_threshold = args.min_symbiosis
    config.safety.enable_safety = args.enable_safety_gates

    # Data pipeline settings. NOTE(review): the model is sized with
    # args.max_seq_len while the data uses args.seq_length -- confirm the
    # two flags are kept consistent by callers.
    config.data.dataset_path = Path(args.input_path) if args.input_path else None
    config.data.max_sequence_length = args.seq_length
    config.data.num_workers = args.num_workers

    config.output_dir = Path(args.output_path)
    config.seed = args.seed

    config.device = "cuda" if torch.cuda.is_available() else "cpu"

    # Lazy %-style args so formatting is skipped when the level filters it.
    logger.info("Starting training with config: %s", config.experiment_name)
    logger.info(
        "Model: %sd, %sL, %sH",
        config.model.d_model,
        config.model.num_layers,
        config.model.nhead,
    )
    logger.info("Device: %s", config.device)

    model = BitTransformerLM(**config.model.to_dict())
    model = model.to(config.device)

    # Synthetic corpus: uniformly random bits, seeded for reproducibility.
    logger.info("Creating synthetic training data...")
    torch.manual_seed(config.seed)
    data = torch.randint(0, 2, (args.dataset_size, config.data.max_sequence_length))

    logger.info("Starting training...")
    try:
        train_loop(
            model,
            data,
            epochs=config.training.epochs,
            batch_size=config.training.batch_size,
            amp=config.training.amp,
            compile_model=config.training.compile_model,
            log=True,
        )

        # Fix: ensure the output directory exists before saving; otherwise
        # save_model fails when the directory was never created.
        config.output_dir.mkdir(parents=True, exist_ok=True)
        save_path = config.output_dir / "model_final.pt"
        save_model(model, save_path)
        logger.info("Model saved to %s", save_path)

    except Exception:
        # Fix: log the full traceback, not just str(e).
        logger.exception("Training failed")
        sys.exit(1)
| |
|
| |
|
def infer_cli() -> None:
    """CLI entry point for BitTransformerLM inference.

    Loads trained weights, generates text from ``--prompt`` using either
    diffusion mode, safety-gated decoding, or plain autoregressive
    sampling, and prints the result. Exits with status 1 when the weights
    file is missing or generation fails.
    """
    parser = create_inference_parser()
    parser.add_argument("--prompt", type=str, required=True, help="Text prompt for generation")
    parser.add_argument("--max-tokens", type=int, default=50, help="Maximum tokens to generate")
    parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature")
    parser.add_argument("--use-diffusion", action="store_true", help="Use diffusion mode")
    args = parser.parse_args()

    setup_logging(args.log_level)
    logger = logging.getLogger(__name__)

    # Fail fast with a clear message if the checkpoint is missing.
    if not Path(args.weights_path).exists():
        logger.error("Model weights not found at %s", args.weights_path)
        sys.exit(1)

    logger.info("Loading model from %s", args.weights_path)
    model = load_model(args.weights_path)
    model.eval()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    logger.info("Model loaded on %s", device)
    logger.info("Prompt: %s", args.prompt)

    try:
        if args.use_diffusion:
            logger.info("Using diffusion inference mode")
            prompt_bits = text_to_bits(args.prompt)
            # 9 bits per generated token -- presumably 8 data bits plus a
            # parity bit; confirm against bit_io's encoding.
            length = len(prompt_bits) + args.max_tokens * 9

            generated_bits = diffusion_inference(
                model,
                length=length,
                steps=args.diffusion_steps,
                schedule=args.noise_schedule,
            )
            result = bits_to_text(generated_bits[0].tolist())

        elif args.enable_safety_gates:
            # Safety-gated decoding: enforce complexity/symbiosis floors.
            result = infer_text(
                model,
                args.prompt,
                c_floor=args.max_complexity,
                s_floor=args.min_symbiosis,
            )

        else:
            # Plain sampling without safety telemetry gating.
            from .bit_io import sample_text

            result = sample_text(
                model,
                args.prompt,
                max_new_tokens=args.max_tokens,
                temperature=args.temperature,
            )

        print(f"\nGenerated text:\n{result}")

    except Exception:
        # Fix: log the full traceback, not just str(e).
        logger.exception("Inference failed")
        sys.exit(1)
| |
|
| |
|
def dashboard_cli() -> None:
    """CLI entry point for the BitTransformerLM dashboard.

    Parses host/port/share options and launches the interactive dashboard
    server. Exits with status 1 if the dashboard fails to start.
    """
    parser = BitTransformerCLI.create_standard_parser(
        "BitTransformerLM Dashboard",
        ["io"],
    )
    parser.add_argument("--host", type=str, default="127.0.0.1", help="Dashboard host")
    parser.add_argument("--port", type=int, default=7860, help="Dashboard port")
    parser.add_argument("--share", action="store_true", help="Create public link")
    args = parser.parse_args()

    setup_logging(args.log_level)
    logger = logging.getLogger(__name__)

    logger.info("Starting BitTransformerLM dashboard on %s:%s", args.host, args.port)

    try:
        # Blocks until the server stops; any startup error is fatal.
        run_dashboard(
            host=args.host,
            port=args.port,
            share=args.share,
        )
    except Exception:
        # Fix: log the full traceback, not just str(e).
        logger.exception("Dashboard failed to start")
        sys.exit(1)
| |
|
| |
|
if __name__ == "__main__":
    # Console-script style dispatch: the executable's basename selects the
    # subcommand. Order matters -- "train" is matched before "dashboard".
    import os

    script_name = os.path.basename(sys.argv[0])

    dispatch = (
        ("train", train_cli),
        ("infer", infer_cli),
        ("dashboard", dashboard_cli),
    )
    for needle, command in dispatch:
        if needle in script_name:
            command()
            break
    else:
        # Unknown invocation: print usage and fail.
        print("Available commands:")
        print(" bit-transformer-train - Train a BitTransformerLM model")
        print(" bit-transformer-infer - Run inference with a trained model")
        print(" bit-transformer-dashboard - Launch interactive dashboard")
        sys.exit(1)