WCNegentropy committed on
Commit
b08919a
·
verified ·
1 Parent(s): 2ca4b28

🚀 Refined BitTransformerLM: Organized codebase with best practices

Browse files
Files changed (1) hide show
  1. bit_transformer/__init__.py +114 -56
bit_transformer/__init__.py CHANGED
@@ -1,23 +1,18 @@
 
 
 
1
  from .model import (
2
- PositionalEncoding,
3
  BitTransformerLM,
 
4
  ReversibleLoggingTransformerEncoderLayer,
5
- example_usage,
6
  example_training_step,
 
7
  infer_long_sequence,
8
- diffusion_inference,
9
  )
10
- from .telemetry import TelemetrySynthesizer, detect_metric_drift
11
- from .dashboard import plot_telemetry
12
- from .dashboard_app import run_dashboard
13
- from .collapse import collapse_submodel, save_distilled_model
14
- from .safety import hil_safe_inference, demo_hil_safety, safe_sample_with_retry
15
- from .bit_io import (
16
- text_to_bits,
17
- bits_to_text,
18
- infer_text,
19
- )
20
- from .parity import enforce_parity
21
  from .compression import (
22
  compress_bits,
23
  decompress_bits,
@@ -25,62 +20,125 @@ from .compression import (
25
  pack_bits,
26
  unpack_bits,
27
  )
28
- from .distributed import wrap_fsdp, make_pipeline
29
- from .optimization import configure_optimizer, adjust_learning_rate
 
 
 
 
 
 
 
30
  from .scale import expand_model
31
- from .distil import distill_step, TelemetryLog
32
- from .quantization import (
33
- quantize_dynamic,
34
- prepare_qat_fx,
35
- convert_qat_fx,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  )
37
- from .training import train_loop
38
- from .utils import save_model, load_model, set_dropout
39
- from .hf_checkpoint import hf_login, save_checkpoint, download_checkpoint
 
 
 
 
 
 
40
  from .torch_utils import cpu_autocast
 
41
 
42
  __all__ = [
43
- "PositionalEncoding",
44
  "BitTransformerLM",
 
45
  "ReversibleLoggingTransformerEncoderLayer",
46
- "example_usage",
47
  "example_training_step",
48
- "TelemetrySynthesizer",
49
- "detect_metric_drift",
50
- "collapse_submodel",
51
- "save_distilled_model",
52
- "hil_safe_inference",
53
- "demo_hil_safety",
54
- "safe_sample_with_retry",
55
- "text_to_bits",
56
  "bits_to_text",
57
- "infer_text",
58
- "enforce_parity",
59
- "plot_telemetry",
60
- "run_dashboard",
61
- "configure_optimizer",
62
- "adjust_learning_rate",
63
- "expand_model",
64
- "distill_step",
65
- "TelemetryLog",
66
- "quantize_dynamic",
67
- "prepare_qat_fx",
68
- "convert_qat_fx",
69
- "train_loop",
70
- "wrap_fsdp",
71
- "make_pipeline",
72
  "compress_bits",
73
  "decompress_bits",
 
 
74
  "model_output_decompress",
75
  "pack_bits",
 
76
  "unpack_bits",
77
- "infer_long_sequence",
78
- "diffusion_inference",
79
- "save_model",
80
- "load_model",
81
- "set_dropout",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  "hf_login",
 
 
 
83
  "save_checkpoint",
84
- "download_checkpoint",
85
- "cpu_autocast",
86
  ]
 
1
+ """BitTransformerLM: Bit-native transformer with reversible layers and telemetry."""
2
+
3
+ # Core model components
4
  from .model import (
 
5
  BitTransformerLM,
6
+ PositionalEncoding,
7
  ReversibleLoggingTransformerEncoderLayer,
8
+ diffusion_inference,
9
  example_training_step,
10
+ example_usage,
11
  infer_long_sequence,
 
12
  )
13
+
14
+ # I/O and data processing
15
+ from .bit_io import bits_to_text, infer_text, text_to_bits
 
 
 
 
 
 
 
 
16
  from .compression import (
17
  compress_bits,
18
  decompress_bits,
 
20
  pack_bits,
21
  unpack_bits,
22
  )
23
+ from .parity import enforce_parity
24
+
25
+ # Training and optimization
26
+ from .optimization import adjust_learning_rate, configure_optimizer
27
+ from .training import train_loop
28
+
29
+ # Model scaling and distillation
30
+ from .collapse import collapse_submodel, save_distilled_model
31
+ from .distil import TelemetryLog, distill_step
32
  from .scale import expand_model
33
+
34
+ # Distributed computing
35
+ from .distributed import make_pipeline, wrap_fsdp
36
+
37
+ # Quantization support
38
+ from .quantization import convert_qat_fx, prepare_qat_fx, quantize_dynamic
39
+
40
+ # Safety and monitoring
41
+ from .safety import demo_hil_safety, hil_safe_inference, safe_sample_with_retry
42
+ from .telemetry import TelemetrySynthesizer, detect_metric_drift
43
+
44
+ # Configuration management
45
+ from .config import (
46
+ DataConfig,
47
+ ExperimentConfig,
48
+ ModelConfig,
49
+ SafetyConfig,
50
+ TrainingConfig,
51
+ get_config_from_env,
52
+ get_large_config,
53
+ get_medium_config,
54
+ get_small_config,
55
  )
56
+
57
+ # Command-line interface
58
+ from .cli import dashboard_cli, infer_cli, train_cli
59
+ from .cli_standards import BitTransformerCLI
60
+
61
+ # Visualization and utilities
62
+ from .dashboard import plot_telemetry
63
+ from .dashboard_app import run_dashboard
64
+ from .hf_checkpoint import download_checkpoint, hf_login, save_checkpoint
65
  from .torch_utils import cpu_autocast
66
+ from .utils import load_model, save_model, set_dropout
67
 
68
  __all__ = [
69
+ # Core model components
70
  "BitTransformerLM",
71
+ "PositionalEncoding",
72
  "ReversibleLoggingTransformerEncoderLayer",
73
+ "diffusion_inference",
74
  "example_training_step",
75
+ "example_usage",
76
+ "infer_long_sequence",
77
+
78
+ # I/O and data processing
 
 
 
 
79
  "bits_to_text",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  "compress_bits",
81
  "decompress_bits",
82
+ "enforce_parity",
83
+ "infer_text",
84
  "model_output_decompress",
85
  "pack_bits",
86
+ "text_to_bits",
87
  "unpack_bits",
88
+
89
+ # Training and optimization
90
+ "adjust_learning_rate",
91
+ "configure_optimizer",
92
+ "train_loop",
93
+
94
+ # Model scaling and distillation
95
+ "collapse_submodel",
96
+ "distill_step",
97
+ "expand_model",
98
+ "save_distilled_model",
99
+ "TelemetryLog",
100
+
101
+ # Distributed computing
102
+ "make_pipeline",
103
+ "wrap_fsdp",
104
+
105
+ # Quantization support
106
+ "convert_qat_fx",
107
+ "prepare_qat_fx",
108
+ "quantize_dynamic",
109
+
110
+ # Safety and monitoring
111
+ "demo_hil_safety",
112
+ "detect_metric_drift",
113
+ "hil_safe_inference",
114
+ "safe_sample_with_retry",
115
+ "TelemetrySynthesizer",
116
+
117
+ # Configuration management
118
+ "DataConfig",
119
+ "ExperimentConfig",
120
+ "get_config_from_env",
121
+ "get_large_config",
122
+ "get_medium_config",
123
+ "get_small_config",
124
+ "ModelConfig",
125
+ "SafetyConfig",
126
+ "TrainingConfig",
127
+
128
+ # Command-line interface
129
+ "BitTransformerCLI",
130
+ "dashboard_cli",
131
+ "infer_cli",
132
+ "train_cli",
133
+
134
+ # Visualization and utilities
135
+ "cpu_autocast",
136
+ "download_checkpoint",
137
  "hf_login",
138
+ "load_model",
139
+ "plot_telemetry",
140
+ "run_dashboard",
141
  "save_checkpoint",
142
+ "save_model",
143
+ "set_dropout",
144
  ]