File size: 7,814 Bytes
2f70b79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
"""Command-line interface entry points for BitTransformerLM."""

import sys
import logging
from pathlib import Path
from typing import Optional

import torch

from .cli_standards import create_training_parser, create_inference_parser, BitTransformerCLI
from .config import (
    ExperimentConfig,
    ModelConfig,
    TrainingConfig,
    SafetyConfig,
    DataConfig,
    get_small_config,
    get_medium_config,
    get_large_config,
)
from .model import BitTransformerLM, diffusion_inference
from .training import train_loop
from .bit_io import text_to_bits, bits_to_text, infer_text
from .utils import save_model, load_model
from .dashboard_app import run_dashboard


def setup_logging(level: str = "INFO") -> None:
    """Configure the root logger to emit timestamped records on stdout.

    Args:
        level: Name of a level constant on the ``logging`` module. The name
            is upper-cased before lookup, so ``"debug"`` and ``"DEBUG"`` are
            equivalent.

    Raises:
        AttributeError: If the upper-cased name is not an attribute of the
            ``logging`` module.
    """
    # Resolve the level first so an unknown name fails before any handler
    # is created.
    resolved_level = getattr(logging, level.upper())
    record_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    stdout_handler = logging.StreamHandler(sys.stdout)

    logging.basicConfig(
        level=resolved_level,
        format=record_format,
        handlers=[stdout_handler],
    )


def train_cli() -> None:
    """CLI entry point for training BitTransformerLM models.

    Parses the arguments defined by ``create_training_parser``, builds an
    ``ExperimentConfig`` (from a size preset when ``--model-size`` names one),
    copies the parsed values onto it, trains a ``BitTransformerLM`` on
    randomly generated bit sequences, and saves the result as
    ``model_final.pt`` inside the output directory.

    The process exits with status 1 if the output directory cannot be
    created, or if training or saving raises an exception.
    """
    parser = create_training_parser()
    args = parser.parse_args()

    setup_logging(args.log_level)
    logger = logging.getLogger(__name__)

    # Get preset configuration if specified
    if args.model_size == "small":
        config = get_small_config()
    elif args.model_size == "medium":
        config = get_medium_config()
    elif args.model_size == "large":
        config = get_large_config()
    else:
        config = ExperimentConfig()

    # Override with command line arguments
    # NOTE(review): these assignments are unconditional, so whatever value
    # argparse holds (including any parser default) replaces the preset's
    # value -- confirm against create_training_parser that this is intended.
    config.model.d_model = args.d_model
    config.model.nhead = args.num_heads
    config.model.num_layers = args.num_layers
    config.model.max_seq_len = args.max_seq_len

    config.training.epochs = args.epochs
    config.training.batch_size = args.batch_size
    config.training.learning_rate = args.learning_rate
    config.training.weight_decay = args.weight_decay
    config.training.gradient_clip_val = args.grad_clip
    config.training.warmup_steps = args.warmup_steps
    config.training.amp = args.use_amp
    config.training.compile_model = args.compile_model

    config.safety.k_threshold = args.min_negentropy
    config.safety.c_threshold = args.max_complexity
    config.safety.s_threshold = args.min_symbiosis
    config.safety.enable_safety = args.enable_safety_gates

    config.data.dataset_path = Path(args.input_path) if args.input_path else None
    config.data.max_sequence_length = args.seq_length
    config.data.num_workers = args.num_workers

    config.output_dir = Path(args.output_path)
    config.seed = args.seed

    # Set device
    if torch.cuda.is_available():
        config.device = "cuda"
    else:
        config.device = "cpu"

    logger.info(f"Starting training with config: {config.experiment_name}")
    logger.info(f"Model: {config.model.d_model}d, {config.model.num_layers}L, {config.model.nhead}H")
    logger.info(f"Device: {config.device}")

    # Make sure the checkpoint destination exists before spending time on
    # training; otherwise a missing directory would only surface when the
    # final save is attempted.
    try:
        config.output_dir.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        logger.error(f"Cannot create output directory {config.output_dir}: {e}")
        sys.exit(1)

    # Create model
    model = BitTransformerLM(**config.model.to_dict())
    model = model.to(config.device)

    # Create synthetic dataset for demonstration
    # (config.data.dataset_path is recorded above but not read here.)
    logger.info("Creating synthetic training data...")
    torch.manual_seed(config.seed)
    data = torch.randint(0, 2, (args.dataset_size, config.data.max_sequence_length))

    # Train model
    logger.info("Starting training...")
    try:
        train_loop(
            model,
            data,
            epochs=config.training.epochs,
            batch_size=config.training.batch_size,
            amp=config.training.amp,
            compile_model=config.training.compile_model,
            log=True,
        )

        # Save model
        save_path = config.output_dir / "model_final.pt"
        save_model(model, save_path)
        logger.info(f"Model saved to {save_path}")

    except Exception as e:
        logger.error(f"Training failed: {e}")
        # Keep the default output terse, but make the traceback available
        # when the user asks for DEBUG logging.
        logger.debug("Traceback for training failure", exc_info=True)
        sys.exit(1)


def infer_cli() -> None:
    """Generate text from a prompt using saved BitTransformerLM weights.

    Extends the parser returned by ``create_inference_parser`` with
    ``--prompt``, ``--max-tokens``, ``--temperature`` and ``--use-diffusion``,
    loads the model from ``args.weights_path``, and prints the generated text.

    Generation takes one of three paths: ``diffusion_inference`` when
    ``--use-diffusion`` is set, otherwise ``infer_text`` when safety gates are
    enabled, otherwise ``sample_text``.

    The process exits with status 1 when the weights path does not exist or
    when generation raises an exception.
    """
    parser = create_inference_parser()
    extra_options = (
        ("--prompt", dict(type=str, required=True, help="Text prompt for generation")),
        ("--max-tokens", dict(type=int, default=50, help="Maximum tokens to generate")),
        ("--temperature", dict(type=float, default=1.0, help="Sampling temperature")),
        ("--use-diffusion", dict(action="store_true", help="Use diffusion mode")),
    )
    for flag, options in extra_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    setup_logging(args.log_level)
    logger = logging.getLogger(__name__)

    # Refuse to continue without a checkpoint on disk.
    weights_path = args.weights_path
    if not Path(weights_path).exists():
        logger.error(f"Model weights not found at {weights_path}")
        sys.exit(1)

    logger.info(f"Loading model from {weights_path}")
    model = load_model(weights_path)
    model.eval()

    # Prefer the GPU when one is available.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    model = model.to(device)

    logger.info(f"Model loaded on {device}")
    logger.info(f"Prompt: {args.prompt}")

    try:
        if args.use_diffusion:
            logger.info("Using diffusion inference mode")
            # Only the prompt's bit count feeds the target length; the prompt
            # bits themselves are not handed to diffusion_inference.
            prompt_bit_count = len(text_to_bits(args.prompt))
            # Approximate sizing; 9 bits per token presumably mirrors the
            # bit_io text encoding -- TODO confirm.
            total_bits = prompt_bit_count + args.max_tokens * 9

            generated_bits = diffusion_inference(
                model,
                length=total_bits,
                steps=args.diffusion_steps,
                schedule=args.noise_schedule,
            )
            result = bits_to_text(generated_bits[0].tolist())
        elif args.enable_safety_gates:
            # Autoregressive generation guarded by the safety thresholds.
            result = infer_text(
                model,
                args.prompt,
                c_floor=args.max_complexity,
                s_floor=args.min_symbiosis,
            )
        else:
            # Plain sampling with no safety gating.
            from .bit_io import sample_text

            result = sample_text(
                model,
                args.prompt,
                max_new_tokens=args.max_tokens,
                temperature=args.temperature,
            )

        print(f"\nGenerated text:\n{result}")

    except Exception as err:
        logger.error(f"Inference failed: {err}")
        sys.exit(1)


def dashboard_cli() -> None:
    """Launch the BitTransformerLM dashboard from the command line.

    Builds a standard parser via ``BitTransformerCLI.create_standard_parser``
    with the ``"io"`` argument group, adds ``--host``, ``--port`` and
    ``--share``, then passes those values to ``run_dashboard``.

    The process exits with status 1 if ``run_dashboard`` raises an exception.
    """
    parser = BitTransformerCLI.create_standard_parser("BitTransformerLM Dashboard", ["io"])
    parser.add_argument("--host", type=str, default="127.0.0.1", help="Dashboard host")
    parser.add_argument("--port", type=int, default=7860, help="Dashboard port")
    parser.add_argument("--share", action="store_true", help="Create public link")
    args = parser.parse_args()

    setup_logging(args.log_level)
    logger = logging.getLogger(__name__)

    host, port = args.host, args.port
    logger.info(f"Starting BitTransformerLM dashboard on {host}:{port}")

    try:
        run_dashboard(host=host, port=port, share=args.share)
    except Exception as err:
        logger.error(f"Dashboard failed to start: {err}")
        sys.exit(1)


if __name__ == "__main__":
    # Choose the entry point from the name this script was invoked under.
    import os

    invoked_as = os.path.basename(sys.argv[0])

    # Checked in order; the first keyword found in the script name wins.
    entry_points = (
        ("train", train_cli),
        ("infer", infer_cli),
        ("dashboard", dashboard_cli),
    )
    for keyword, entry_point in entry_points:
        if keyword in invoked_as:
            entry_point()
            break
    else:
        # No keyword matched: list the installed commands and signal failure.
        print("Available commands:")
        print("  bit-transformer-train    - Train a BitTransformerLM model")
        print("  bit-transformer-infer    - Run inference with a trained model")
        print("  bit-transformer-dashboard - Launch interactive dashboard")
        sys.exit(1)