WCNegentropy committed
Commit 5a75fec · verified · 1 Parent(s): d16c938

🚀 Refined BitTransformerLM: Organized codebase with best practices


BitTransformerLM refined with ML engineering best practices:

✅ **Organized Codebase Structure**
- Cleaned up 30+ scattered scripts into organized directories
- Standardized imports and docstring formatting
- Consolidated configuration management
- Professional package metadata

✅ **Enhanced Developer Experience**
- Comprehensive CLI interface with standardized arguments
- Type-safe configuration system with presets
- Improved error handling and logging
- Better modular organization

✅ **Production Quality**
- PyProject.toml with proper dependencies and tooling
- Consistent code formatting and documentation
- Maintainable directory structure
- Ready for serious development and research

The bit-native transformer architecture with reversible layers, safety telemetry,
and distributed training capabilities is now properly packaged for research use.

CLAUDE.md DELETED
@@ -1,404 +0,0 @@
# BitTransformerLM Claude Code Integration Guide

## Overview

BitTransformerLM is designed for use with [Claude Code](https://claude.ai/code), providing AI-assisted setup, development, and research workflows. This document gives guidelines for working with BitTransformerLM both in Claude Code and in standalone development.

## Why Claude Code?

BitTransformerLM's bit-native architecture has several complexities that Claude Code can help navigate:

- **Complex Architecture**: Understanding bit-level processing, reversible layers, and safety telemetry
- **Parameter Tuning**: Optimizing model configurations for different use cases
- **Safety Monitoring**: Interpreting K/C/S metrics and configuring safety gates
- **Distributed Training**: Setting up FSDP and pipeline parallelism correctly
- **Debugging**: Identifying issues specific to bit-native processing

Claude Code understands these nuances and can provide real-time assistance.

---

## Repository Scope and Architecture

### Core Capabilities
BitTransformerLM implements bit-native language modeling with:
- **Bit-Native Processing**: Direct binary sequence modeling with parity protection
- **Reversible Layers**: Memory-efficient transformer blocks designed to save roughly 50% of activation memory
- **Safety Telemetry**: Real-time K/C/S (Negentropy/Complexity/Symbiosis) monitoring
- **Diffusion Mode**: Bidirectional denoising with multiple noise schedules
- **Progressive Scaling**: Automatic model expansion based on validation performance
- **Distributed Training**: FSDP and pipeline parallelism for large-scale training
- **Interactive Dashboard**: Real-time training control and visualization

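The parity protection mentioned above can be illustrated with a minimal sketch: each byte of text expands to 8 data bits plus one even-parity check bit. This is an illustrative encoding only — the actual scheme in `bit_io.py` may differ in bit order and framing.

```python
def text_to_bits(text: str) -> list[int]:
    """Encode text as bits: 8 data bits per byte plus one even-parity bit."""
    bits = []
    for byte in text.encode("utf-8"):
        data = [(byte >> i) & 1 for i in range(7, -1, -1)]  # MSB first
        bits.extend(data)
        bits.append(sum(data) % 2)  # even-parity check bit
    return bits


def bits_to_text(bits: list[int]) -> str:
    """Decode 9-bit groups back to text, verifying each group's parity bit."""
    out = bytearray()
    for i in range(0, len(bits), 9):
        group = bits[i:i + 9]
        data, parity = group[:8], group[8]
        assert sum(data) % 2 == parity, "parity error: corrupted bit group"
        out.append(int("".join(map(str, data)), 2))
    return out.decode("utf-8")


encoded = text_to_bits("hi")
assert bits_to_text(encoded) == "hi"
assert len(encoded) == 2 * 9  # 9 bits per input byte
```

A flipped bit inside any 9-bit group trips the parity assertion on decode, which is the kind of corruption the model's bit-level framing is meant to catch.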
### Experimental Status
**Important**: BitTransformerLM is experimental research software requiring:
- Rigorous baseline comparisons against standard transformers
- Validation on established language modeling benchmarks
- Statistical significance testing across multiple runs
- Careful interpretation of safety metrics and claims

---

## Environment Setup

### Requirements
- **Python 3.10+** (required for modern PyTorch features)
- **PyTorch 2.7.1+** with appropriate CUDA support if using GPUs

### Installation Options

#### CPU-Only Installation
```bash
pip install --extra-index-url https://download.pytorch.org/whl/cpu -r requirements.txt
```

#### GPU Installation
```bash
pip install --extra-index-url https://download.pytorch.org/whl/cu118 torch==2.7.1+cu118
pip install -r requirements.txt
```

#### Claude Code Assisted Setup
When using Claude Code, simply ask for:
- "Help me set up BitTransformerLM for my system"
- "Configure BitTransformerLM for GPU training"
- "Set up a development environment for bit-native language modeling"

Claude Code will guide you through hardware detection, dependency installation, and initial configuration.

---

## Repository Structure

```
BitTransformerLM/
├── bit_transformer/          # Core package
│   ├── model.py              # BitTransformerLM architecture
│   ├── telemetry.py          # K/C/S safety metrics
│   ├── safety.py             # Safety gates and monitoring
│   ├── bit_io.py             # Text ↔ bits conversion
│   ├── compression.py        # Run-length encoding
│   ├── training.py           # Training utilities
│   ├── distributed.py        # FSDP and pipeline parallelism
│   ├── dashboard_app.py      # Interactive web dashboard
│   ├── quantization.py       # INT8/4-bit quantization
│   └── [other modules...]    # Additional functionality
├── tests/                    # Test suite and results
├── example.py                # Basic usage example
├── unified_workflow.py       # Main training script
├── mcp_server.py             # Management Control Protocol server
├── USER_GUIDE.md             # Comprehensive user documentation
└── [other scripts...]        # Utilities and examples
```

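The run-length encoding handled by `compression.py` can be sketched as follows (a simplified illustration; the module's on-disk format may differ). Bit streams with long runs of identical bits compress well under this scheme:

```python
def rle_encode(bits: list[int]) -> list[tuple[int, int]]:
    """Collapse a bit sequence into (bit, run_length) pairs."""
    runs: list[tuple[int, int]] = []
    for b in bits:
        if runs and runs[-1][0] == b:
            runs[-1] = (b, runs[-1][1] + 1)  # extend the current run
        else:
            runs.append((b, 1))              # start a new run
    return runs


def rle_decode(runs: list[tuple[int, int]]) -> list[int]:
    """Expand (bit, run_length) pairs back into the original bit sequence."""
    return [b for b, n in runs for _ in range(n)]


bits = [0, 0, 0, 1, 1, 0, 1, 1, 1, 1]
runs = rle_encode(bits)
assert runs == [(0, 3), (1, 2), (0, 1), (1, 4)]
assert rle_decode(runs) == bits
```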
---

## Development Workflow with Claude Code

### Getting Started

1. **Initial Setup**
   ```
   "Help me understand BitTransformerLM's architecture"
   "Create a simple training script for bit-native language modeling"
   "Explain the difference between causal and diffusion modes"
   ```

2. **Model Configuration**
   ```
   "Configure a BitTransformerLM for [my specific use case]"
   "What are optimal hyperparameters for a [size] model?"
   "Help me enable reversible layers and gradient checkpointing"
   ```

3. **Training and Monitoring**
   ```
   "Set up distributed training with FSDP"
   "Interpret these K/C/S telemetry values: K=0.3, C=0.6, S=0.4"
   "Debug this memory error during training"
   ```

### Claude Code Advantages

**Real-time Assistance**: Get immediate help with:
- Parameter configuration and tuning
- Error diagnosis and resolution
- Architecture modification and experimentation
- Safety metric interpretation
- Performance optimization

**Context-Aware Suggestions**: Claude Code understands:
- BitTransformerLM's unique bit-native processing
- The relationship between safety metrics
- Memory optimization strategies
- Distributed training complexities

---

## Key Commands and Workflows

### Basic Usage
```bash
# Run simple example
python example.py

# Launch interactive dashboard
python unified_workflow.py --dashboard

# Train with diffusion mode
python unified_workflow.py --diffusion --diffusion-steps 8 --dataset-size 32
```

### Advanced Training
```bash
# Distributed training with FSDP
python unified_workflow.py --distributed --batch-size 2 --epochs 10

# Mixed precision with quantization
python unified_workflow.py --amp --qat

# Progressive scaling with curriculum learning
python unified_workflow.py --progressive --diffusion-curriculum
```

### Dashboard and Monitoring
```bash
# Start MCP server and dashboard
python mcp_server.py &
MCP_SERVER_ADDR=http://127.0.0.1:7000 python -m bit_transformer.dashboard_app
```

**Dashboard Features:**
- Real-time telemetry visualization
- Interactive model configuration
- HuggingFace checkpoint management
- Safe inference testing interface

---

## Safety and Telemetry

### Core Metrics

| Metric | Full Name | Range | Interpretation |
|--------|-----------|-------|----------------|
| **K** | Negentropy | 0-1 | Information content (0 = noise, 1 = ordered) |
| **C** | LZ Complexity | 0-1 | Pattern complexity (higher = more complex) |
| **S** | Symbiosis | 0-1 | Alignment with reference (higher = better) |

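A rough sketch of how metrics like K and C can be computed for a bit sequence (illustrative formulas only; `telemetry.py`'s exact definitions and normalizations may differ):

```python
import math


def negentropy(bits: list[int]) -> float:
    """1 - H(p)/H_max over the bit distribution: 0 for a fair-coin stream, 1 for a constant stream."""
    p = sum(bits) / len(bits)
    if p in (0.0, 1.0):
        return 1.0
    h = -(p * math.log2(p) + (1 - p) * math.log2(1 - p))
    return 1.0 - h


def lz_complexity(bits: list[int]) -> float:
    """Count new phrases in a left-to-right LZ78-style parse, normalized by sequence length."""
    seen: set[str] = set()
    phrase, count = "", 0
    for b in bits:
        phrase += str(b)
        if phrase not in seen:
            seen.add(phrase)
            count += 1
            phrase = ""  # start the next phrase
    return count / len(bits)


assert negentropy([1, 1, 1, 1]) == 1.0  # fully ordered stream
assert negentropy([0, 1, 0, 1]) == 0.0  # maximum-entropy stream
```

Under these illustrative definitions, an ordered stream scores high K and low C, while a random stream scores low K and high C — which matches the table's interpretation columns.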
### Using with Claude Code

```
"Explain what K=0.2, C=0.8, S=0.3 means for my model"
"Configure safety gates for production use"
"My model is generating repetitive output, what safety metrics should I check?"
"Set up drift detection for telemetry monitoring"
```

Claude Code can help interpret these metrics in context and suggest appropriate safety thresholds.

### Safety Gate Configuration
```python
from bit_transformer.safety import SafetyGate

# Production-ready safety gate
gate = SafetyGate(
    c_floor=0.3,  # Minimum complexity
    s_floor=0.5,  # Minimum symbiosis
    decay=0.9,    # EMA decay factor
    burn_in=10    # Steps before gating starts
)
```

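The EMA smoothing and burn-in behavior can be sketched as follows — a simplified stand-in for the real `SafetyGate`, with a hypothetical `update` method, not the library's actual API:

```python
class EmaGate:
    """Trips when the EMA of a metric falls below a floor, after a burn-in period."""

    def __init__(self, floor: float, decay: float = 0.9, burn_in: int = 10):
        self.floor, self.decay, self.burn_in = floor, decay, burn_in
        self.ema: float | None = None
        self.steps = 0

    def update(self, value: float) -> bool:
        """Return True if the gate should trip (halt or flag generation)."""
        if self.ema is None:
            self.ema = value
        else:
            self.ema = self.decay * self.ema + (1 - self.decay) * value
        self.steps += 1
        if self.steps <= self.burn_in:
            return False  # never gate during burn-in
        return self.ema < self.floor


gate = EmaGate(floor=0.5, decay=0.9, burn_in=3)
readings = [0.2, 0.2, 0.2, 0.9, 0.9]
tripped = [gate.update(v) for v in readings]
assert tripped == [False, False, False, True, True]
```

Note how the EMA recovers slowly after the metric improves: with `decay=0.9`, the gate stays tripped for several steps after the readings jump to 0.9. That lag is the point of EMA smoothing — single-step spikes neither trip nor clear the gate.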
---

## Best Practices for Claude Code Development

### 1. **Always Validate Research Claims**
Ask Claude Code to help you:
- Set up proper baseline comparisons
- Design statistical significance tests
- Implement evaluation on standard benchmarks
- Document limitations and assumptions

### 2. **Use Progressive Development**
```
"Start me with a minimal BitTransformerLM example"
"Now add safety monitoring"
"Scale up to distributed training"
"Add diffusion mode capabilities"
```

### 3. **Leverage Claude Code for Architecture Understanding**
```
"Explain how reversible layers save memory"
"Walk me through the bit encoding process"
"How does the safety telemetry system work?"
"Compare BitTransformerLM to standard transformers"
```

### 4. **Get Help with Complex Configurations**
```python
# Ask Claude Code to help configure models like:
model = BitTransformerLM(
    d_model=1024,          # Claude Code can suggest optimal values
    nhead=16,              # based on your hardware and use case
    num_layers=20,
    dim_feedforward=4096,
    max_seq_len=2048,
    reversible=True,       # Memory optimization
    use_checkpoint=True,   # Gradient checkpointing
    chunk_size=256,        # Attention chunking
    lambda_K=0.1,          # Regularization weights
    lambda_C=0.1,
    lambda_S=0.1
)
```

---

## Development Guidelines

### Code Style
- **Functions**: `snake_case` (e.g., `train_loop`, `safe_inference`)
- **Classes**: `CamelCase` (e.g., `BitTransformerLM`, `SafetyGate`)
- **Constants**: `UPPER_SNAKE_CASE` (e.g., `MAX_SEQ_LEN`)
- **Keep functions under 300 lines** and minimize deep nesting

### Security and Safety
- **Never reintroduce the deprecated `/exec` endpoint**
- **Always use safety gates in production**
- **Validate all user inputs** in dashboard and API endpoints
- **Monitor telemetry metrics** for anomalous behavior
- **Use the `cpu_autocast()` helper** instead of calling `torch.amp.autocast` directly

### Memory Management
```python
# Good: memory-efficient configuration
model = BitTransformerLM(
    reversible=True,         # Enable reversible layers
    use_checkpoint=True,     # Gradient checkpointing
    chunk_size=128,          # Chunked attention
    full_attn_logging=False  # Skip full attention reconstruction
)

# Training with memory optimizations
train_loop(
    model, data,
    amp=True,           # Mixed precision
    accum_steps=4,      # Gradient accumulation
    compile_model=True  # torch.compile optimization
)
```

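Gradient accumulation keeps the effective batch size constant while lowering peak memory; the averaging it performs can be sketched in plain Python (illustrative only — `train_loop` handles this internally, and `accumulate` here is a hypothetical helper):

```python
def accumulate(grads_per_microbatch: list[float], accum_steps: int):
    """Average per-microbatch gradients so each update matches one large batch."""
    total = 0.0
    for step, g in enumerate(grads_per_microbatch, start=1):
        total += g / accum_steps  # scale each micro-batch's gradient
        if step % accum_steps == 0:
            yield total           # optimizer.step() would fire here
            total = 0.0


# batch_size=1 with accum_steps=4 behaves like one batch of 4
micro_grads = [4.0, 8.0, 2.0, 6.0]
updates = list(accumulate(micro_grads, accum_steps=4))
assert updates == [5.0]  # mean of the four micro-batch gradients
```

This is why `accum_steps=4` with a small batch trades wall-clock time (four forward/backward passes per update) for memory, without changing the optimization trajectory of the equivalent large batch.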
### Testing and Validation
```bash
# Run tests after changes
pytest -q
```

```python
# Model evaluation modes
model.train()            # For training
model.eval()             # For inference/evaluation
set_dropout(model, 0.0)  # Disable dropout for reproducible results
```

---

## Getting Help from Claude Code

### Specific Areas Where Claude Code Excels

1. **Architecture Design**
   - "Design a BitTransformerLM architecture for [specific task]"
   - "Optimize this model configuration for memory efficiency"
   - "Explain the trade-offs between reversible and standard layers"

2. **Training Optimization**
   - "My training is running out of memory, help optimize"
   - "Configure distributed training for 4 GPUs"
   - "Set up a training curriculum for bit-native language modeling"

3. **Safety and Monitoring**
   - "Interpret these telemetry readings and suggest adjustments"
   - "Set up production-ready safety monitoring"
   - "Debug why my safety gate is triggering constantly"

4. **Research and Evaluation**
   - "Design a rigorous evaluation comparing BitTransformerLM to GPT-2"
   - "Set up statistical significance testing for my experiments"
   - "Help me write up my research findings honestly"

### Sample Claude Code Interactions

```
User: "I'm getting OOM errors training a 500M parameter BitTransformerLM"

Claude Code: "Let me help optimize your memory usage. Here's a configuration
that should work better for your model size:

model = BitTransformerLM(
    d_model=768,
    nhead=12,
    num_layers=12,
    reversible=True,         # Critical for large models
    use_checkpoint=True,     # Trade compute for memory
    chunk_size=64,           # Reduce attention memory
    full_attn_logging=False  # Skip expensive logging
)

train_loop(
    model, data,
    batch_size=1,    # Small batch size
    accum_steps=16,  # Maintain effective batch size
    amp=True         # Mixed precision training
)

This should reduce memory usage by ~60% compared to the standard configuration."
```

---

## Licensing and Distribution

BitTransformerLM is available under dual licensing:
- **Open Source**: AGPLv3 for research and open source use
- **Commercial**: Contact **contact@wcnegentropy.com** for commercial licensing

When working with Claude Code, ensure compliance with the AGPLv3 license for any derivatives or modifications you create.

---

## Research Integrity

**Important Reminder**: BitTransformerLM is experimental research software. When using Claude Code:

1. **Always validate claims** through proper baseline comparisons
2. **Document limitations** honestly in any publications or reports
3. **Use statistical significance testing** for any performance claims
4. **Follow established ML research best practices**
5. **Share negative results** as well as positive ones

Claude Code can help you design rigorous experiments and avoid common pitfalls in ML research.

---

## Support and Community

### Getting Help
- **Claude Code**: Real-time AI assistance with BitTransformerLM
- **GitHub Issues**: Bug reports and feature requests
- **Discussions**: Community questions and sharing
- **User Guide**: Comprehensive documentation (`USER_GUIDE.md`)
- **Project Overview**: Complete project information (`ABOUTME.md`)

### Contributing
When contributing to BitTransformerLM:
1. Use Claude Code to ensure code quality and consistency
2. Follow the development guidelines in this document
3. Add tests for new functionality
4. Update documentation as needed
5. Ensure all safety and security practices are followed

---

**BitTransformerLM + Claude Code provides a powerful combination for exploring bit-native language modeling with AI assistance. Start experimenting responsibly and share your findings with the research community!** 🤖✨
EMPIRICAL_VALIDATION.md DELETED
@@ -1,147 +0,0 @@
# BitTransformerLM Empirical Validation Report

**Report Date:** August 2025
**Data Sources:** Test results, training logs, forensic analysis
**Validation Level:** Initial experimental validation only

## Validated Claims vs Empirical Evidence

This document provides a rigorous assessment of what has been empirically validated versus what remains unsubstantiated or requires further testing.

### ✅ **EMPIRICALLY VALIDATED CLAIMS**

#### Architecture Implementation
- **✓ Bit-native processing:** Successfully processes binary sequences (0/1) as input
  - *Evidence:* Successful training on bit sequences from parity-encoded text
  - *Test cases:* Both 793K and 771M parameter models
- **✓ Reversible layers:** Mathematically reversible transformer blocks implemented and functional
  - *Evidence:* Models train successfully with the `reversible=True` configuration
  - *Measured benefit:* Implementation complete; memory benefit theoretical (not measured vs baseline)
- **✓ Multi-head attention:** Adapted for bit embeddings with configurable heads (2-28 tested)
  - *Evidence:* Models train with various attention head configurations

#### Safety and Telemetry Systems
- **✓ K/C/S metric computation:** Negentropy, LZ complexity, and symbiosis calculations functional
  - *Evidence:* Metrics computed during training: K≈0.0013, C≈0.52, S≈0.46
  - *Limitation:* Values based on limited training data; effectiveness unvalidated
- **✓ Real-time monitoring:** Dashboard displays metrics during training
  - *Evidence:* Working web interface with live metric updates
- **✓ Safety gates:** EMA-smoothed thresholds prevent generation below configured limits
  - *Evidence:* Implementation present; triggers when thresholds are violated

#### Training Infrastructure
- **✓ FSDP implementation:** Fully Sharded Data Parallel training code present
  - *Evidence:* Successfully trained a 771M parameter model
  - *Scale limit:* Only tested up to 771M parameters, not billion+ scale
- **✓ Mixed precision:** FP16/BF16 training with CPU autocast support
  - *Evidence:* Training logs show mixed precision usage
- **✓ Progressive scaling:** Architecture expansion based on performance metrics
  - *Evidence:* Code implementation validates; mechanism functional
- **✓ Quantization support:** Dynamic INT8 and experimental 4-bit QAT
  - *Evidence:* Implementation present; basic functionality validated

#### Training Results
- **✓ Small-scale convergence:** 793K parameter model converges on toy data
  - *Evidence:* Loss: 0.779 → 0.571 over 5 epochs (0.21 s training)
  - *Limitation:* Toy dataset (4 samples, sequence length 16)
- **✓ Medium-scale training:** 771M parameter model trains without crashing
  - *Evidence:* 5 epochs completed; loss reduction: 11.84 → 5.35
  - *Limitation:* Minimal dataset (5 samples with padding), insufficient for language modeling assessment
- **✓ Inference generation:** Models generate bit sequences successfully
  - *Evidence:* 100% success rate on test prompts in both configurations

### ⚠️ **UNVALIDATED OR OVERSTATED CLAIMS**

#### Performance and Efficiency
- **⚠️ "50%+ memory reduction":** Theoretical, based on the reversible architecture design
  - *Status:* No empirical measurement vs baseline transformers
  - *Required:* Controlled comparison with equivalent standard models
- **⚠️ "Memory-efficient processing":** Implementation suggests efficiency, but not measured
  - *Status:* No quantitative comparison to baseline memory usage
  - *Required:* Systematic memory profiling vs standard transformers
- **⚠️ "Superior scaling behavior":** No evidence of scaling advantages
  - *Status:* Only tested up to 771M parameters on toy datasets
  - *Required:* Large-scale comparative studies vs standard models

#### Capability Claims
- **⚠️ "Language modeling capability":** Training data insufficient for assessment
  - *Status:* Models trained only on toy datasets (4-5 samples)
  - *Required:* Training and evaluation on standard language modeling benchmarks
- **⚠️ "Production-ready system":** Experimental status contradicts production claims
  - *Status:* No baseline comparisons or real-world evaluation
  - *Required:* Rigorous validation against established benchmarks
- **⚠️ "Revolutionary/groundbreaking":** Marketing language not supported by comparative evidence
  - *Status:* Novel approach, but benefits undemonstrated vs alternatives
  - *Required:* Peer review and comparative analysis

#### Scale and Distribution
- **⚠️ "Billion+ parameter scaling":** Largest validated model is 771M parameters
  - *Status:* FSDP code supports larger models but is not empirically validated at that scale
  - *Evidence contradiction:* Forensic analysis shows 771M ≠ 1B despite some claims
- **⚠️ "Multi-GPU efficiency":** A single GPU was actually used despite multi-GPU claims
  - *Status:* Code supports FSDP, but the largest training run used `device_ids=[0]` only
  - *Required:* True distributed training validation and efficiency measurement

### ❌ **REFUTED CLAIMS**

#### Parameter Count Accuracy
- **✗ "Working 1B Parameter Model":** Actually 771,176,450 parameters (771M)
  - *Evidence:* Forensic analysis of model configuration and training logs
  - *Discrepancy:* 23% less than the claimed 1B parameters
- **✗ "Multi-GPU training":** Actually single-GPU training
  - *Evidence:* `device_ids=[0]` in configuration; only GPU 0 utilized
  - *Misrepresentation:* Claims of 4-GPU training while using a single GPU

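The 23% figure follows directly from the reported counts:

```python
claimed, actual = 1_000_000_000, 771_176_450
shortfall = (claimed - actual) / claimed
print(f"{shortfall:.1%}")  # roughly 23% below the claimed 1B parameters
```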
## Empirical Evidence Summary

### Training Data Analysis
**Small Model (793K parameters):**
- Dataset: 4 samples, sequence length 16
- Training time: 0.21 seconds
- Final loss: 0.629; best loss: 0.571
- **Assessment:** Toy validation only, insufficient for capability claims

**Large Model (771M parameters):**
- Dataset: 5 text samples with zero-padding
- Training time: 11.47 seconds
- Hardware: Single NVIDIA L4 GPU (15.28 GB peak memory)
- Loss trajectory: Chaotic pattern suggesting insufficient data
- **Assessment:** Technical validation of scale, but inadequate training data

### Telemetry Data Analysis
- **K (Negentropy):** 0.0013 (low information content, consistent with limited training data)
- **C (LZ Complexity):** 0.52 (moderate complexity, within expected range)
- **S (Symbiosis):** 0.46 (below optimum, consistent with limited training)
- **Assessment:** Metrics functional, but values reflect training data limitations

## Required Evidence for Substantiated Claims

### For Memory Efficiency Claims
1. **Controlled Memory Measurement:** Direct comparison with equivalent standard transformers
2. **Scale Analysis:** Memory usage patterns across different model sizes
3. **Peak Memory Profiling:** Training and inference memory requirements vs baselines

### For Performance Claims
1. **Standard Benchmarks:** WikiText-103, Penn Treebank, and other established datasets
2. **Multiple Runs:** Statistical significance testing with confidence intervals
3. **Convergence Analysis:** Long-duration training to true convergence
4. **Comparative Evaluation:** Head-to-head performance vs standard architectures

### For Scaling Claims
1. **True Large Scale:** >1B parameter models with proper distributed training
2. **Scaling Laws:** Parameter-vs-performance relationships compared to baselines
3. **Efficiency Analysis:** Training cost and time comparisons at scale

## Conclusion

**What is Validated:** BitTransformerLM is a complete, functional experimental implementation of bit-native language modeling with sophisticated monitoring and safety systems.

**What Requires Validation:** All claims about efficiency, capability, and advantages over standard approaches require rigorous empirical validation through proper baseline comparisons.

**What is Refuted:** Some historical documentation contained factually incorrect claims about parameter counts and hardware usage, which have been corrected.

**Research Status:** The implementation provides an excellent foundation for rigorous research evaluation, but requires extensive validation work before any practical claims can be substantiated.

---

*This empirical validation report reflects only what can be verified through available evidence. All claims about advantages, efficiency, or superior performance remain hypotheses requiring systematic investigation through proper ML research methodology.*
OPEN_SOURCE_LAUNCH.md DELETED
@@ -1,192 +0,0 @@
1
- # BitTransformerLM Open Source Launch
2
-
3
- **Launch Date:** August 2025
4
- **Version:** v0.1.0 (Pre-release)
5
- **Status:** Experimental Research Release
6
-
7
- ## What We're Launching
8
-
9
- BitTransformerLM is an experimental transformer language model that processes text at the bit level rather than using traditional tokenization. This open source release provides a complete research framework for exploring bit-native language modeling approaches.
10
-
11
- ### Key Innovations
12
-
13
- **Bit-Native Architecture:** Processes binary sequences (0/1) directly with custom bit embeddings and positional encodings, enabling fine-grained control over information processing.
14
-
15
- **Reversible Layers:** Implements mathematically reversible transformer blocks that theoretically enable memory-efficient computation by avoiding intermediate activation storage.
16
-
17
- **Safety-First Design:** Built-in real-time telemetry (K/C/S metrics) monitors negentropy, compressibility, and alignment during training and inference with configurable safety gates.
18
-
19
- **Research Infrastructure:** Comprehensive framework including distributed training (FSDP), interactive dashboard, progressive scaling, and extensive testing suite.
20
-
21
- ## What This Release Includes
22
-
23
- ### ✅ **Complete Implementation**
24
- - 57 Python files with 10,699+ lines of research code
25
- - Full transformer architecture adapted for bit-level processing
26
- - FSDP distributed training support (tested to 771M parameters)
27
- - Interactive web dashboard for training control and monitoring
28
- - Comprehensive test suite with automated CI validation
29
- - Mixed precision training with quantization support
30
-
31
- ### ✅ **Validated Functionality**
32
- - Successful training on small (793K) and medium (771M) parameter scales
33
- - Functional safety telemetry and monitoring systems
34
- - Working inference with bit sequence generation
35
- - Progressive scaling and architecture expansion
36
- - Real-time dashboard monitoring and control
37
-
38
- ### ✅ **Development Tools**
39
- - MCP (Management Control Protocol) server for integration
40
- - HuggingFace Hub integration for model sharing
41
- - Docker containerization for reproducible deployment
42
- - CLI tools and example scripts
43
- - Comprehensive documentation and API reference
44
-
45
- ## Important Limitations and Disclaimers
46
-
47
- ### ⚠️ **Research Status**
48
- - **Experimental Implementation:** This is research code exploring a novel approach
49
- - **No Baseline Comparisons:** Has not been rigorously evaluated against standard transformers
50
- - **Limited Training Data:** Validated only on toy datasets insufficient for language modeling assessment
51
- - **Unverified Claims:** Memory efficiency and performance benefits are theoretical until properly measured
52
-
53
- ### ⚠️ **Not Production Ready**
54
- - Requires extensive validation before any production use
55
- - Missing critical baseline evaluations on standard benchmarks
56
- - Training conducted only on minimal datasets (4-5 samples)
57
- - Performance claims relative to standard approaches are unsubstantiated
58
-
59
- ### ⚠️ **Validation Needed**
60
- - Comparative studies vs equivalent standard transformers
61
- - Long-duration training on real language modeling datasets
62
- - Statistical significance testing across multiple runs
63
- - Memory and compute efficiency measurement vs baselines
64
-
65
- ## Intended Use Cases
66
-
67
- ### ✅ **Recommended Research Applications**
68
- - **Academic Research:** Novel architecture exploration and bit-level modeling studies
69
- - **AI Safety Research:** Telemetry system development and safety monitoring research
70
- - **Memory Efficiency Studies:** Reversible architecture investigation and optimization
71
- - **Educational Use:** Learning about transformer internals and experimental architectures
72
-
73
- ### ❌ **Not Recommended**
74
- - Production applications without rigorous validation
75
- - Direct comparison claims without proper baseline studies
76
- - Commercial deployment without extensive testing
77
- - Any use case requiring proven performance advantages
78
-
79
- ## Getting Started
80
-
81
- ### Installation
82
- ```bash
83
- # Clone repository
84
- git clone https://github.com/WCNegentropy/BitTransformerLM.git
85
- cd BitTransformerLM
86
-
87
- # Install dependencies
88
- pip install -r requirements.txt
89
-
90
- # Run basic example
91
- python example.py
92
-
93
- # Launch interactive dashboard
94
- python unified_workflow.py --dashboard
95
- ```
96
-
97
- ### Basic Usage
98
- ```python
99
- from bit_transformer import BitTransformerLM
100
-
101
- # Create model
102
- model = BitTransformerLM(
103
- d_model=64,
104
- nhead=4,
105
- num_layers=2,
106
- dim_feedforward=128,
107
- max_seq_len=64
108
- )
109
-
110
- # Train on bit sequences
111
- bits = torch.randint(0, 2, (batch_size, seq_len))
112
- logits, telemetry = model(bits)
113
- ```
114
-
- ## Community and Contributions
-
- ### How to Contribute
- - **Bug Reports:** Use GitHub Issues for reproducible bug reports
- - **Feature Requests:** Propose enhancements with clear use cases
- - **Pull Requests:** Follow existing code style and include tests
- - **Research Results:** Share findings from validation studies and comparisons
-
- ### Research Collaboration
- We encourage researchers to:
- - Conduct rigorous baseline comparisons
- - Evaluate on standard language modeling benchmarks
- - Share results (positive or negative) with the community
- - Extend the architecture for specific research questions
-
- ### Documentation
- - **ABOUTME.md:** Quick start and feature overview
- - **README.md:** Professional model card with specifications and limitations
- - **RESEARCH_STATUS.md:** Current research status and validation needs
- - **EMPIRICAL_VALIDATION.md:** What has been validated vs what requires further study
-
- ## License and Usage Terms
-
- **Primary License:** AGPLv3 (see LICENSE/LICENSE.txt)
- **Additional Terms:** See LICENSE/ directory for complete framework
- - Commercial licensing available (see COMMERCIAL_LICENSE.txt)
- - Contributor License Agreement required (see CONTRIBUTOR_LICENSE_AGREEMENT.txt)
- - Trademark policy and disclaimers included
-
- ## Future Development
-
- ### Immediate Priorities
- 1. **Rigorous Baseline Studies:** Comprehensive evaluation vs standard transformers
- 2. **Standard Dataset Training:** WikiText-103, Penn Treebank evaluation
- 3. **Statistical Validation:** Multiple runs with significance testing
- 4. **Memory Efficiency Measurement:** Quantitative analysis vs baselines
-
- ### Research Directions
- 1. **Scaling Studies:** True large-scale (1B+ parameter) validation with proper distributed training
- 2. **Application Studies:** Identify scenarios where bit-level processing provides advantages
- 3. **Safety System Validation:** Evaluate K/C/S telemetry effectiveness across diverse scenarios
- 4. **Hardware Optimization:** Custom kernels and neuromorphic computing exploration
-
- ## Citation
-
- ```bibtex
- @software{bittransformerlm2025,
-   title={BitTransformerLM: Experimental Bit-Native Transformer Language Model},
-   author={WCNegentropy Research},
-   year={2025},
-   url={https://github.com/WCNegentropy/BitTransformerLM},
-   version={0.1.0},
-   note={Experimental research implementation}
- }
- ```
-
- ## Contact and Support
-
- - **Repository:** https://github.com/WCNegentropy/BitTransformerLM
- - **Issues:** GitHub Issues for bug reports and technical questions
- - **Discussions:** GitHub Discussions for research questions and community discussion
- - **License Questions:** See LICENSE/ directory or contact maintainers
-
- ---
-
- ## Launch Statement
-
- We are excited to release BitTransformerLM as an open source research project exploring bit-native language modeling. This implementation represents a complete experimental framework with potential for advancing memory-efficient transformer architectures and interpretable AI systems.
-
- **Important:** This is experimental research code. While the implementation is complete and functional, it requires extensive validation through proper baseline comparisons before any practical claims can be made. We encourage the research community to help validate (or refute) the potential benefits of this approach through rigorous scientific methodology.
-
- The future of this project depends on community validation and research. We welcome contributions, comparisons, and honest evaluation of the approach's merits and limitations.
-
- **Research responsibly. Validate rigorously. Share openly.**
-
- ---
-
- *BitTransformerLM v0.1.0 - Experimental Research Release - August 2025*
RESEARCH_STATUS.md DELETED
@@ -1,140 +0,0 @@
- # BitTransformerLM Research Status Report
-
- **Date:** August 2025
- **Status:** Experimental Implementation Complete
- **Validation Level:** Pre-baseline Evaluation
-
- ## Executive Summary
-
- BitTransformerLM represents a complete experimental implementation of bit-native language modeling with reversible transformer architecture. The project demonstrates the feasibility of the approach and provides a comprehensive research framework. However, the implementation requires rigorous validation against standard baselines before any production considerations.
-
- ## Current Implementation Status
-
- ### ✅ **Completed Components**
-
- **Core Architecture:**
- - Bit-native input processing (0/1 binary sequences)
- - Reversible transformer layers for memory efficiency
- - Multi-head attention adapted for bit-level representations
- - Progressive scaling with automatic architecture expansion
- - Experimental diffusion mode for bidirectional generation
-
- **Safety and Monitoring:**
- - Real-time telemetry (K/C/S metrics): Negentropy, LZ Complexity, Symbiosis
- - Safety gates with EMA smoothing and configurable thresholds
- - Metric drift detection and alerting systems
- - Human-in-the-loop safe inference with retry mechanisms
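The K (negentropy) and C (compressibility) metrics, together with the EMA-smoothed gating around them, can be illustrated with plain-Python stand-ins. The formulas below are simplified sketches chosen for readability, not the repository's exact implementations:

```python
import math


def negentropy(bits: list[int]) -> float:
    """K-style metric: 1 - H(p)/H_max, so 0 = maximally random, 1 = fully ordered."""
    if not bits:
        return 0.0
    p = sum(bits) / len(bits)
    if p in (0.0, 1.0):
        return 1.0
    entropy = -(p * math.log2(p) + (1 - p) * math.log2(1 - p))
    return 1.0 - entropy  # H_max is 1 bit for a binary source


def lz_phrase_count(bits: list[int]) -> int:
    """C-style metric: LZ78-like phrase count as a crude compressibility proxy."""
    seen: set[str] = set()
    phrase, count = "", 0
    for b in bits:
        phrase += str(b)
        if phrase not in seen:
            seen.add(phrase)
            count += 1
            phrase = ""
    return count + (1 if phrase else 0)


class EMAGate:
    """Safety gate: trips when the smoothed metric falls below a floor."""

    def __init__(self, floor: float, alpha: float = 0.1) -> None:
        self.floor, self.alpha, self.ema = floor, alpha, None

    def update(self, value: float) -> bool:
        self.ema = value if self.ema is None else (
            self.alpha * value + (1 - self.alpha) * self.ema)
        return self.ema >= self.floor  # False means the gate has tripped
```

A run that collapses to constant output drives `lz_phrase_count` down while `negentropy` rises, which is the kind of drift the real telemetry and its `c_floor`/`s_floor` thresholds are meant to surface.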
-
- **Training Infrastructure:**
- - FSDP distributed training support (validated up to 771M parameters)
- - Mixed precision training (FP16/BF16 with CPU autocast)
- - Gradient checkpointing for memory efficiency
- - Quantization support (dynamic INT8 + experimental 4-bit QAT)
- - Chunked attention for long sequence processing
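Chunked attention and the dashboard's long-inference mode both rest on splitting a long bit stream into overlapping windows. A stdlib sketch of that windowing (the parameter names `ctx_bits` and `overlap` are borrowed from the dashboard form; the stride logic is an assumption, not the repository's exact scheme):

```python
def windows(bits: list[int], ctx_bits: int, overlap: int) -> list[list[int]]:
    """Slice a long sequence into ctx_bits-sized chunks sharing `overlap` bits."""
    if ctx_bits <= overlap:
        raise ValueError("ctx_bits must exceed overlap")
    stride = ctx_bits - overlap
    out = []
    for start in range(0, max(len(bits) - overlap, 1), stride):
        out.append(bits[start:start + ctx_bits])
    return out


# Neighbouring chunks share `overlap` bits of context.
chunks = windows([0, 1] * 8, ctx_bits=8, overlap=2)
```

Each chunk is processed independently, so peak attention cost scales with `ctx_bits` rather than the full sequence length; the shared overlap preserves some context across chunk boundaries.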
-
- **Development Tools:**
- - Interactive web dashboard for training control and monitoring
- - MCP (Management Control Protocol) server for integration
- - HuggingFace Hub integration for model sharing
- - Comprehensive test suite (11 test modules)
- - CI/CD pipeline with automated testing
-
- ### 📊 **Empirical Results**
-
- **Small-Scale Validation (793K parameters):**
- - Training: Successful convergence on toy dataset (4 samples, 16 seq length)
- - Loss reduction: 0.779 → 0.571 in 5 epochs (0.21s training time)
- - Inference: 100% success rate on test prompts
- - Memory: Minimal resource usage
-
- **Medium-Scale Validation (771M parameters):**
- - Training: 5 epochs on limited dataset (5 samples with padding)
- - Hardware: Single GPU with 15.28 GB peak memory usage
- - Loss progression: 11.84 → 5.35 (showing learning but on insufficient data)
- - Telemetry: K≈0.0013, C≈0.52, S≈0.46 (limited by training data)
- - Inference: 100% success on test prompts with bit generation
-
- ## Critical Limitations and Research Needs
-
- ### ⚠️ **Validation Gaps**
-
- **Missing Baseline Comparisons:**
- - No systematic evaluation against standard transformer architectures
- - No performance comparison on established benchmarks (WikiText, Penn Treebank, etc.)
- - No efficiency analysis compared to token-based approaches
- - No scaling law establishment relative to conventional models
-
- **Training Data Limitations:**
- - Experiments conducted only on toy datasets insufficient for language modeling
- - Largest training used 5 short text samples with heavy zero-padding
- - No evaluation on real-world corpora or standard datasets
- - Training durations too short to establish genuine convergence patterns
-
- **Scale Verification Needed:**
- - Largest successfully trained model: 771M parameters (not 1B+ as claimed in some docs)
- - FSDP distributed training tested but not at true large scale
- - Memory efficiency claims need quantitative validation against baselines
- - Scalability to billion+ parameter models requires verification
-
- ### 🔬 **Research Questions Requiring Investigation**
-
- 1. **Efficiency Claims:** Does bit-native processing provide memory/compute advantages over token-based models of equivalent capacity?
-
- 2. **Learning Capability:** Can bit-level models achieve comparable performance to standard transformers on language modeling benchmarks?
-
- 3. **Scaling Behavior:** How do bit-native models scale compared to conventional architectures in terms of parameters, data, and compute?
-
- 4. **Safety Effectiveness:** Do K/C/S telemetry metrics provide reliable safety monitoring compared to existing approaches?
-
- 5. **Practical Applications:** What use cases, if any, benefit from bit-level granularity over standard tokenization?
-
- ## Recommended Research Agenda
-
- ### Phase 1: Baseline Establishment (High Priority)
- 1. **Standard Dataset Evaluation:** Train on WikiText-103, Penn Treebank, other established benchmarks
- 2. **Comparative Analysis:** Direct comparison with equivalent-parameter standard transformers
- 3. **Statistical Validation:** Multiple runs with significance testing and confidence intervals
- 4. **Performance Profiling:** Systematic memory and compute analysis vs baselines
-
- ### Phase 2: Scaling Studies (Medium Priority)
- 1. **True Large-Scale Training:** 1B+ parameter models with proper distributed training
- 2. **Convergence Analysis:** Long-duration training to establish learning dynamics
- 3. **Scaling Law Investigation:** Parameter vs performance relationships
- 4. **Resource Efficiency:** Quantitative memory and compute efficiency analysis
-
- ### Phase 3: Application Validation (Lower Priority)
- 1. **Use Case Analysis:** Identify scenarios where bit-level processing provides advantages
- 2. **Safety System Evaluation:** Validate K/C/S metrics on diverse datasets and failure modes
- 3. **Production Readiness:** Real-world deployment studies with proper evaluation protocols
- 4. **Community Validation:** External evaluation and peer review processes
-
- ## Technical Debt and Known Issues
-
- ### Documentation Inconsistencies
- - Some historical documentation contains overstated claims (addressed in cleanup)
- - Parameter count discrepancies between different documents (corrected)
- - Multi-GPU usage claims not matching actual implementation (clarified)
-
- ### Code Quality
- - Security issues identified and resolved (removed `/exec` endpoint)
- - Minor import and edge-case bugs identified in audit (fixed)
- - Test coverage comprehensive but focused on unit tests vs integration scenarios
-
- ### Performance Optimization Opportunities
- - Vectorization of compression/decompression operations
- - Memory optimization for long sequence processing
- - Batch processing improvements for training efficiency
-
- ## Conclusion and Recommendations
-
- **Current Status:** BitTransformerLM provides a complete, well-engineered experimental framework for bit-native language modeling research. The implementation demonstrates technical feasibility and includes sophisticated monitoring and safety systems.
-
- **Critical Next Steps:** The project requires rigorous baseline comparisons and statistical validation before any claims about efficiency or capability can be substantiated. The experimental framework is ready for serious research evaluation.
-
- **Research Potential:** If validation studies demonstrate advantages in specific scenarios, BitTransformerLM could contribute to memory-efficient language modeling and interpretable AI systems. However, these benefits must be rigorously established through proper scientific methodology.
-
- **Production Readiness:** Not recommended for production use without extensive validation. The experimental nature and lack of baseline comparisons make it unsuitable for anything beyond research applications.
-
- ---
-
- *This report reflects the actual technical status based on forensic analysis of implementation, testing results, and documentation. It supersedes any inflated claims in historical documents and provides an honest foundation for future research directions.*
bit_transformer/static/style.css DELETED
@@ -1,93 +0,0 @@
- :root {
-   --primary: #1e40af;
-   --bg: #f5f6fa;
- }
-
- body {
-   font-family: Arial, sans-serif;
-   background-color: var(--bg);
-   margin: 0;
-   padding: 0;
-   line-height: 1.5;
-   color: #333;
- }
-
- .container {
-   max-width: 900px;
-   margin: 0 auto;
-   padding-bottom: 2rem;
- }
-
- h1 {
-   text-align: center;
-   background: var(--primary);
-   color: #fff;
-   margin: 0;
-   padding: 1rem 0;
- }
-
- section {
-   background: #fff;
-   margin: 1rem auto;
-   padding: 1rem 1.5rem;
-   border-radius: 8px;
-   box-shadow: 0 2px 4px rgba(0,0,0,0.1);
-   width: 90%;
-   max-width: 800px;
- }
-
- section h2 {
-   margin-top: 0;
-   color: var(--primary);
-   font-size: 1.25rem;
- }
-
- form {
-   display: flex;
-   flex-wrap: wrap;
-   gap: 0.5rem 1rem;
- }
-
- form input[type="text"],
- form input[type="number"],
- form textarea {
-   flex: 1 1 200px;
-   padding: 0.4em;
-   border: 1px solid #ccc;
-   border-radius: 4px;
- }
-
- form button,
- button#scaleBtn {
-   padding: 0.4em 0.8em;
-   border: none;
-   background: var(--primary);
-   color: #fff;
-   border-radius: 4px;
-   cursor: pointer;
- }
-
- form button:hover,
- button#scaleBtn:hover {
-   background-color: #1d4ed8;
- }
-
- pre, p#trainOut {
-   background: #f0f0f0;
-   padding: 0.5rem;
-   border-radius: 4px;
-   overflow-x: auto;
- }
-
- label {
-   display: flex;
-   align-items: center;
-   gap: 0.5rem;
- }
-
- img#plot {
-   max-width: 100%;
-   height: auto;
-   display: block;
-   margin: auto;
- }
bit_transformer/templates/dashboard.html DELETED
@@ -1,454 +0,0 @@
- <!DOCTYPE html>
- <html lang="en">
- <head>
- <meta charset="UTF-8">
- <title>Bit Transformer Dashboard</title>
- <link rel="stylesheet" href="{{ url_for('static', filename='style.css') }}">
- <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
- </head>
- <body>
- <h1>Bit Transformer Dashboard</h1>
- <div class="container">
- <section>
- <h2>Initialize Model</h2>
- <form id="initForm">
- d_model: <input type="number" name="d_model" value="{{ defaults.d_model }}" title="Model width (default {{ defaults.d_model }})"><br>
- nhead: <input type="number" name="nhead" value="{{ defaults.nhead }}" title="Attention heads (default {{ defaults.nhead }})"><br>
- num_layers: <input type="number" name="num_layers" value="{{ defaults.num_layers }}" title="Transformer layers (default {{ defaults.num_layers }})"><br>
- dim_feedforward: <input type="number" name="dim_feedforward" value="{{ defaults.dim_feedforward }}" title="Feedforward dim (default {{ defaults.dim_feedforward }})"><br>
- max_seq_len: <input type="number" name="max_seq_len" value="{{ defaults.max_seq_len }}" title="Max sequence length (default {{ defaults.max_seq_len }})"><br>
- chunk_size: <input type="number" name="chunk_size" title="Chunked attention size"><br>
- overlap: <input type="number" name="overlap" value="{{ defaults.overlap }}" title="Sliding window overlap"><br>
- Reversible: <input type="checkbox" name="reversible" id="reversible_box" title="Use reversible layers (default {{ defaults.reversible }})"><br>
- Gradient Checkpointing: <input type="checkbox" name="use_checkpoint" id="checkpoint_box" checked title="Enable gradient checkpointing (default {{ defaults.use_checkpoint }})"><br>
- act_threshold: <input type="number" step="0.01" name="act_threshold" value="{{ defaults.act_threshold }}" title="ACT halt threshold (default {{ defaults.act_threshold }})"><br>
- c_floor: <input type="number" step="0.01" name="c_floor" value="{{ c_floor }}" title="Complexity floor"><br>
- s_floor: <input type="number" step="0.01" name="s_floor" value="{{ s_floor }}" title="Symbiosis floor"><br>
- <button type="submit">Init</button>
- </form>
- </section>
- <section>
- <h2>Train Step</h2>
- <form id="trainForm">
- Bits (e.g. 0 1 0 1): <input type="text" name="bits" value="0 1 0 1"><br>
- Upload file: <input type="file" id="train_file"><br>
- <button type="submit">Train</button>
- </form>
- <label>Load sample dataset:
- <select id="datasetSelect">
- <option value="">--Select--</option>
- <option value="wikitext2_train">Wikitext-2 (train)</option>
- <option value="wikitext2_validation">Wikitext-2 (validation)</option>
- </select>
- </label>
- <p id="trainOut"></p>
- </section>
- <section>
- <h2>Scale Up</h2>
- Width Mult: <input type="number" step="0.1" id="width_mult" value="1.0"><br>
- <button id="scaleBtn">Scale Model</button>
- </section>
- <section>
- <h2>Collapse Submodel</h2>
- <form id="collapseForm">
- Cluster Bits (JSON array of arrays):<br>
- <textarea name="clusters" rows="3" cols="40">[[0,1,0,1],[1,1,0,0]]</textarea><br>
- Target Params (JSON):<br>
- <textarea name="params" rows="3" cols="40">{"d_model":32,"nhead":4,"num_layers":1,"dim_feedforward":64,"max_seq_len":16}</textarea><br>
- Width Scale: <input type="number" step="0.1" id="width_scale" value="1.0"><br>
- <button type="submit">Collapse</button>
- </form>
- </section>
- <section>
- <h2>Inference</h2>
- <form id="inferForm">
- Bits: <input type="text" name="bits" value="0 1 0 1"><br>
- Upload file: <input type="file" id="infer_file"><br>
- <button type="submit">Infer</button>
- </form>
- <pre id="inferOut"></pre>
- </section>
- <section>
- <h2>Long Inference</h2>
- <form id="inferLongForm">
- Bits: <input type="text" name="bits" value="0 1 0 1"><br>
- ctx_bits: <input type="number" name="ctx_bits" value="4096"><br>
- overlap: <input type="number" name="overlap" value="256"><br>
- <button type="submit">Infer Long</button>
- </form>
- <pre id="inferLongOut"></pre>
- </section>
- <section>
- <h2>Text Inference</h2>
- <form id="textInferForm">
- Text: <input type="text" name="text" value="hello"><br>
- <button type="submit">Infer Text</button>
- </form>
- <pre id="textInferOut"></pre>
- </section>
- <section>
- <h2>&lambda; Weights</h2>
- <form id="lambdaForm">
- &lambda;<sub>K</sub>: <input type="range" min="0" max="2" step="0.1" id="lambda_K" oninput="lambda_K_val.innerText=value"><span id="lambda_K_val"></span><br>
- &lambda;<sub>C</sub>: <input type="range" min="0" max="2" step="0.1" id="lambda_C" oninput="lambda_C_val.innerText=value"><span id="lambda_C_val"></span><br>
- &lambda;<sub>S</sub>: <input type="range" min="0" max="2" step="0.1" id="lambda_S" oninput="lambda_S_val.innerText=value"><span id="lambda_S_val"></span><br>
- <button type="submit">Update</button>
- </form>
- </section>
- <section>
- <h2>Diffusion LM</h2>
- <label><input type="checkbox" id="diffusion_box"> Enable Diffusion Mode</label>
- </section>
- <section>
- <h2>GPU Acceleration</h2>
- <label><input type="checkbox" id="gpu_box"> Enable FSDP &amp; CUDA</label>
- </section>
- <section>
- <h2>Enable Compression</h2>
- <label><input type="checkbox" id="compression_box"> Compress I/O</label>
- <p>Ratio: <span id="comp_ratio">1.0</span></p>
- </section>
- <section>
- <h2>Quantization Aware Training</h2>
- <label><input type="checkbox" id="qat_box"> Enable 4-bit QAT</label>
- </section>
- <section>
- <h2>Model Status</h2>
- <pre id="statusOut"></pre>
- </section>
- <section>
- <h2>Telemetry</h2>
- <canvas id="metricChart" width="600" height="300"></canvas>
- </section>
- <section>
- <h2>Hugging Face Checkpoints</h2>
- Repo ID: <input type="text" id="hf_repo"><br>
- Token: <input type="password" id="hf_token" placeholder="optional"><br>
- <button id="uploadBtn">Upload weights</button>
- <button id="downloadBtn">Download weights</button>
- <p id="hfStatus"></p>
- </section>
-
- <script>
- async function postJSON(url, data){
- const resp = await fetch(url, {method:'POST', headers:{'Content-Type':'application/json'}, body:JSON.stringify(data)});
- return resp.json();
- }
-
- async function pollJob(id){
- while(true){
- const job = await fetch(`/job/${id}`).then(r=>r.json());
- if(job.status === 'completed') return job.result;
- if(job.status === 'error') throw job.error || 'Job failed';
- await new Promise(r=>setTimeout(r, 1000));
- }
- }
-
- function loadInitParams(){
- const saved = JSON.parse(localStorage.getItem('init_params')||'{}');
- const form = document.getElementById('initForm');
- for(const [k,v] of Object.entries(saved)){
- const el = form.elements[k];
- if(!el) continue;
- if(el.type === 'checkbox') el.checked = v; else el.value = v;
- }
- }
- loadInitParams();
-
- function byteArrayToBits(arr){
- const bits=[];
- for(const b of arr){
- for(let i=7;i>=0;i--) bits.push((b>>i)&1);
- }
- return bits;
- }
-
- let trainFileBits=null, inferFileBits=null, datasetBits=null;
-
- async function fileToBits(file){
- if(file.type.startsWith('text')){
- const text = await file.text();
- const res = await postJSON('/text_to_bits', {text});
- return res.bits;
- }
- const buf = await file.arrayBuffer();
- return byteArrayToBits(new Uint8Array(buf));
- }
-
- let metricChart;
- async function initChart(){
- const data = await fetch('/metrics').then(r=>r.json());
- const labels = data.negentropy.map((_,i)=>i);
- const ctx = document.getElementById('metricChart').getContext('2d');
- metricChart = new Chart(ctx, {
- type:'line',
- data:{
- labels:labels,
- datasets:[
- {label:'Negentropy', data:data.negentropy, borderColor:'blue', fill:false},
- {label:'LZ Complexity', data:data.lz_complexity, borderColor:'orange', fill:false},
- {label:'Symbiosis', data:data.symbiosis, borderColor:'green', fill:false}
- ]
- },
- options:{responsive:false, interaction:{mode:'index', intersect:false}}
- });
- }
-
- async function updateChart(){
- const data = await fetch('/metrics').then(r=>r.json());
- const labels = data.negentropy.map((_,i)=>i);
- metricChart.data.labels = labels;
- metricChart.data.datasets[0].data = data.negentropy;
- metricChart.data.datasets[1].data = data.lz_complexity;
- metricChart.data.datasets[2].data = data.symbiosis;
- metricChart.update();
- }
-
- initChart();
- setInterval(updateChart, 2000);
-
- async function refreshStatus(){
- const [s, c] = await Promise.all([fetch('/status'), fetch('/model_config')]);
- const status = await s.json();
- const config = await c.json();
- document.getElementById('statusOut').innerText = JSON.stringify({...status, ...config}, null, 2);
- }
-
- document.getElementById('initForm').addEventListener('submit', async (e)=>{
- e.preventDefault();
- const fd = new FormData(e.target);
- const obj = Object.fromEntries(fd.entries());
- const ints = ['d_model','nhead','num_layers','dim_feedforward','max_seq_len','chunk_size','overlap'];
- ints.forEach(k=>{ if(obj[k]===''){ delete obj[k]; } else obj[k]=parseInt(obj[k]); });
- obj.reversible = document.getElementById('reversible_box').checked;
- obj.use_checkpoint = document.getElementById('checkpoint_box').checked;
- obj.act_threshold = parseFloat(obj.act_threshold);
- const floors = {c_floor: parseFloat(obj.c_floor), s_floor: parseFloat(obj.s_floor)};
- delete obj.c_floor; delete obj.s_floor;
- await postJSON('/init', obj);
- await postJSON('/config/telemetry', floors);
- localStorage.setItem('init_params', JSON.stringify({...obj, ...floors}));
- refreshStatus();
- updateChart();
- });
-
- document.getElementById('trainForm').addEventListener('submit', async (e)=>{
- e.preventDefault();
- const form = e.target;
- let payload;
- if(trainFileBits){
- payload = trainFileBits;
- } else if(datasetBits){
- payload = datasetBits;
- } else {
- payload = [form.bits.value.trim().split(/\s+/).map(Number)];
- }
- for(const el of form.elements) el.disabled = true;
- const out = document.getElementById('trainOut');
- out.innerText = '⏳';
- try{
- const job = await postJSON('/train', {bits: payload});
- const res = await pollJob(job.job_id);
- out.innerText = 'Loss: '+res.loss.toFixed(4);
- if(res.ratio !== undefined){
- document.getElementById('comp_ratio').innerText = res.ratio.toFixed(2);
- }
- } catch(err){
- out.innerText = 'Error';
- alert(err);
- } finally {
- for(const el of form.elements) el.disabled = false;
- refreshStatus();
- updateChart();
- }
- });
-
- document.getElementById('train_file').addEventListener('change', async (e)=>{
- const f = e.target.files[0];
- if(!f) return;
- const bits = await fileToBits(f);
- trainFileBits = [bits];
- datasetBits = null;
- document.querySelector('#trainForm input[name="bits"]').value = bits.slice(0,64).join(' ');
- });
-
- document.querySelector('#trainForm input[name="bits"]').addEventListener('input', ()=>{
- trainFileBits = null;
- datasetBits = null;
- });
-
- document.getElementById('scaleBtn').addEventListener('click', async ()=>{
- const btn = document.getElementById('scaleBtn');
- const input = document.getElementById('width_mult');
- const mult = parseFloat(input.value);
- btn.disabled = true; input.disabled = true;
- const original = btn.innerText; btn.innerText = '⏳';
- try{
- const job = await postJSON('/scale_up', {width_mult: mult});
- await pollJob(job.job_id);
- } catch(err){
- alert(err);
- } finally {
- btn.innerText = original;
- btn.disabled = false; input.disabled = false;
- refreshStatus();
- updateChart();
- }
- });
-
- document.getElementById('collapseForm').addEventListener('submit', async (e)=>{
- e.preventDefault();
- const form = e.target;
- const btn = form.querySelector('button');
- for(const el of form.elements) el.disabled = true;
- const clusters = JSON.parse(form.clusters.value);
- const params = JSON.parse(form.params.value);
- const w = parseFloat(document.getElementById('width_scale').value);
- const original = btn.innerText; btn.innerText = '⏳';
- try{
- const job = await postJSON('/collapse', {clusters: clusters, params: params, width_scale: w});
- await pollJob(job.job_id);
- } catch(err){
- alert(err);
- } finally {
- btn.innerText = original;
- for(const el of form.elements) el.disabled = false;
- refreshStatus();
- updateChart();
- }
- });
-
- document.getElementById('inferForm').addEventListener('submit', async (e)=>{
- e.preventDefault();
- let bits;
- if(inferFileBits){
- bits = inferFileBits;
- } else if(datasetBits){
- bits = [datasetBits[0]];
- } else {
- bits = [e.target.bits.value.trim().split(/\s+/).map(Number)];
- }
- const res = await postJSON('/infer', {bits});
- if(res.error){
- alert(res.error + '\n' + (res.suggestion||''));
- } else {
- document.getElementById('inferOut').innerText = JSON.stringify(res, null, 2);
- if(res.ratio !== undefined){
- document.getElementById('comp_ratio').innerText = res.ratio.toFixed(2);
- }
- }
- refreshStatus();
- updateChart();
- });
-
- document.getElementById('infer_file').addEventListener('change', async (e)=>{
- const f = e.target.files[0];
- if(!f) return;
- const bits = await fileToBits(f);
- inferFileBits = [bits];
- datasetBits = null;
- document.querySelector('#inferForm input[name="bits"]').value = bits.slice(0,64).join(' ');
- });
-
- document.querySelector('#inferForm input[name="bits"]').addEventListener('input', ()=>{
- inferFileBits = null;
- datasetBits = null;
- });
-
- document.getElementById('datasetSelect').addEventListener('change', async (e)=>{
- const val = e.target.value;
- trainFileBits = null;
- inferFileBits = null;
- if(!val){ datasetBits = null; return; }
- const [name, split] = val.split('_');
- const resp = await fetch(`/dataset?name=${name}&split=${split}&size=4&seq_len=64`);
- const data = await resp.json();
- datasetBits = data.bits;
- const preview = data.bits[0].slice(0,64).join(' ');
- document.querySelector('#trainForm input[name="bits"]').value = preview;
- document.querySelector('#inferForm input[name="bits"]').value = preview;
- });
-
- document.getElementById('inferLongForm').addEventListener('submit', async (e)=>{
- e.preventDefault();
- const bits = e.target.bits.value.trim().split(/\s+/).map(Number);
- const ctx = parseInt(e.target.ctx_bits.value);
- const ov = parseInt(e.target.overlap.value);
- const res = await postJSON('/infer_long', {bits: bits, ctx_bits: ctx, overlap: ov});
- document.getElementById('inferLongOut').innerText = JSON.stringify(res, null, 2);
- refreshStatus();
- updateChart();
- });
-
- document.getElementById('textInferForm').addEventListener('submit', async (e)=>{
- e.preventDefault();
- const text = e.target.text.value;
- const res = await postJSON('/infer_text', {text:text});
- document.getElementById('textInferOut').innerText = JSON.stringify(res, null, 2);
- refreshStatus();
- updateChart();
- });
-
- async function loadLambdas(){
- const resp = await fetch('/lambdas');
- const vals = await resp.json();
- for(const k of ['lambda_K','lambda_C','lambda_S']){
- document.getElementById(k).value = vals[k];
- document.getElementById(k+"_val").innerText = vals[k];
- }
- }
-
- document.getElementById('lambdaForm').addEventListener('submit', async (e)=>{
- e.preventDefault();
- const data = {
- lambda_K: parseFloat(document.getElementById('lambda_K').value),
- lambda_C: parseFloat(document.getElementById('lambda_C').value),
- lambda_S: parseFloat(document.getElementById('lambda_S').value),
- };
- await postJSON('/lambdas', data);
- for(const k in data){
- document.getElementById(k+"_val").innerText = data[k];
- }
- refreshStatus();
- });
-
- loadLambdas();
-
- function restoreToggle(id,key,endpoint,field){
- const box = document.getElementById(id);
- const saved = localStorage.getItem(key);
- if(saved !== null){ box.checked = saved === 'true'; postJSON(endpoint,{[field]: box.checked}); }
- box.addEventListener('change', async (e)=>{
- await postJSON(endpoint, {[field]: e.target.checked});
- localStorage.setItem(key, e.target.checked);
- refreshStatus();
- });
- }
-
- restoreToggle('diffusion_box','diffusion','/diffusion','diffusion');
- restoreToggle('gpu_box','use_gpu','/gpu','use_gpu');
- restoreToggle('compression_box','compression','/compression','compression');
- restoreToggle('qat_box','qat','/qat','qat');
-
- document.getElementById('uploadBtn').addEventListener('click', async ()=>{
- const repo = document.getElementById('hf_repo').value;
- const token = document.getElementById('hf_token').value;
- const res = await postJSON('/save_checkpoint', {repo_id: repo, token: token||undefined});
- document.getElementById('hfStatus').innerText = res.status || res.error;
- });
-
- document.getElementById('downloadBtn').addEventListener('click', async ()=>{
- const repo = document.getElementById('hf_repo').value;
- const token = document.getElementById('hf_token').value;
- const res = await postJSON('/download_checkpoint', {repo_id: repo, token: token||undefined});
- document.getElementById('hfStatus').innerText = res.status || res.error;
- refreshStatus();
- updateChart();
- });
-
- refreshStatus();
- </script>
- </div>
- </body>
- </html>
-
build_full_bits.py DELETED
@@ -1,23 +0,0 @@
- import pathlib
- import torch
- from datasets import load_dataset
-
- TXT_MB = 100
- OUT = pathlib.Path('full_bits.pt')
-
-
- def build_bits(out: pathlib.Path = OUT, txt_mb: int = TXT_MB) -> None:
-     ds = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
-     buf = bytearray()
-     for line in ds['text']:
-         buf.extend(line.encode() + b"\n")
-         if len(buf) >= txt_mb * 2 ** 20:
-             break
-     bits = []
-     for byte in buf:
-         bits.extend(int(b) for b in f'{byte:08b}')
-     tensor = torch.tensor(bits, dtype=torch.uint8)
-     torch.save(tensor, out)
-
- if __name__ == '__main__':
-     build_bits()
cpu_edge_training.py DELETED
@@ -1,468 +0,0 @@
- #!/usr/bin/env python3
- """
- CPU-Optimized Edge Deployment BitTransformerLM Training
- Optimized for consumer devices and edge applications.
- """
-
- import os
- import time
- import torch
- import torch.nn.functional as F
- from datasets import load_dataset
-
- from bit_transformer import (
-     BitTransformerLM,
-     text_to_bits,
-     bits_to_text,
-     train_loop,
-     configure_optimizer,
-     save_model,
-     load_model,
-     set_dropout,
-     hil_safe_inference,
-     quantize_dynamic,
- )
- from bit_transformer.torch_utils import cpu_autocast
- from bit_transformer.training import train_loop
-
-
- def create_optimal_cpu_model():
-     """Create BitTransformerLM optimized for CPU edge deployment."""
-     print("🧠 Creating CPU-optimized BitTransformerLM...")
-
-     # Optimal configuration for edge devices:
-     # - Small model size for low memory footprint
-     # - CPU autocast for faster FP16 inference
-     # - No reversible layers (simpler for CPU)
-     # - Gradient checkpointing disabled for speed
-     # - Small context length for efficiency
-
-     model = BitTransformerLM(
-         d_model=64,               # Small embedding dimension (vs 128 default)
-         nhead=4,                  # Fewer attention heads (vs 8 default)
-         num_layers=3,             # Shallow model (vs 4 default)
-         dim_feedforward=128,      # Smaller FFN (vs 512 default)
-         max_seq_len=256,          # Shorter context (vs 1024 default)
-         reversible=False,         # Disable reversible layers (CPU doesn't benefit much)
-         use_checkpoint=False,     # Disable gradient checkpointing (prioritize speed)
-         use_autocast=True,        # Enable CPU autocast for BF16 mixed precision
-         use_act=False,            # Disable ACT for simplicity
-         chunk_size=32,            # Small chunks for memory efficiency
-         full_attn_logging=False,  # Disable attention logging to save memory
-         lambda_K=1.0,             # Standard telemetry weights
-         lambda_C=1.0,
-         lambda_S=1.0,
-     )
-
-     # Calculate model parameters
-     total_params = sum(p.numel() for p in model.parameters())
-     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-
-     print(f"   📊 Model Configuration:")
-     print(f"      d_model: {64}")
-     print(f"      num_layers: {3}")
-     print(f"      nhead: {4}")
-     print(f"      dim_feedforward: {128}")
-     print(f"      max_seq_len: {256}")
-     print(f"      Total parameters: {total_params:,}")
-     print(f"      Trainable parameters: {trainable_params:,}")
-     print(f"      Estimated size: {total_params * 4 / 1024 / 1024:.1f}MB (FP32)")
-     print(f"      With autocast: ~{total_params * 2 / 1024 / 1024:.1f}MB (BF16)")
-
-     return model
-
-
- def load_training_dataset(dataset_size=512, max_len=128):
-     """Load and prepare training dataset optimized for edge training."""
-     print("📚 Loading training dataset...")
-
-     try:
-         # Try to load BitTransformerLM dataset from HuggingFace
-         print("   Attempting to load BitTransformerLM dataset...")
-         dataset = load_dataset("WCNegentropy/BitTransformerLM", split="train[:{}]".format(dataset_size))
-         if dataset and len(dataset) > 0:
-             train_texts = [item['text'] for item in dataset if item.get('text')]
-             if len(train_texts) > 0:
-                 print(f"   ✅ Loaded {len(train_texts)} samples from BitTransformerLM dataset")
-             else:
-                 raise Exception("No text samples found in dataset")
-         else:
-             raise Exception("Dataset empty or not accessible")
-
-     except Exception as e:
-         print(f"   ⚠️ BitTransformerLM dataset not available: {e}")
-         print("   📖 Falling back to WikiText-2...")
-         try:
-             # Fallback to WikiText-2 for training
-             ds = load_dataset("wikitext", "wikitext-2-raw-v1")
-             train_texts = [text for text in ds["train"]["text"] if text.strip()][:dataset_size]
-             print(f"   ✅ Loaded {len(train_texts)} samples from WikiText-2")
-         except Exception as e2:
-             print(f"   ❌ Failed to load WikiText-2: {e2}")
-             print("   🎲 Using synthetic text data...")
-             # Generate simple synthetic text for demonstration
-             synthetic_texts = [
-                 "The quick brown fox jumps over the lazy dog.",
-                 "Machine learning is transforming technology.",
-                 "Edge computing enables local AI processing.",
-                 "BitTransformerLM uses bit-native processing.",
-                 "CPU optimization improves inference speed.",
-                 "Neural networks learn from training data.",
-                 "Transformers use attention mechanisms.",
-                 "Language models understand text patterns.",
-             ]
-             train_texts = (synthetic_texts * (dataset_size // len(synthetic_texts) + 1))[:dataset_size]
-             print(f"   ✅ Generated {len(train_texts)} synthetic samples")
-
-     # Convert text to bits
-     print("   🔄 Converting text to bits...")
-     train_sequences = []
-     valid_sequences = []
-
-     for i, text in enumerate(train_texts):
-         try:
-             bits = text_to_bits(text)[:max_len]
-             if len(bits) < max_len:
-                 bits.extend([0] * (max_len - len(bits)))  # Pad to max_len
-
-             # Use 80/20 split for train/validation
-             if i < len(train_texts) * 0.8:
-                 train_sequences.append(bits)
-             else:
-                 valid_sequences.append(bits)
-
-         except Exception as e:
-             print(f"   ⚠️ Failed to convert text to bits: {e}")
-             continue
-
-     train_tensor = torch.tensor(train_sequences, dtype=torch.long)
-     valid_tensor = torch.tensor(valid_sequences, dtype=torch.long) if valid_sequences else train_tensor[:16]
-
-     print(f"   📊 Dataset Statistics:")
-     print(f"      Training sequences: {len(train_sequences)}")
-     print(f"      Validation sequences: {len(valid_sequences)}")
-     print(f"      Sequence length: {max_len}")
-     print(f"      Training tensor shape: {train_tensor.shape}")
-
-     return train_tensor, valid_tensor, train_texts[:len(train_sequences)]
-
-
- def train_cpu_optimized_model(model, train_data, valid_data, epochs=5):
-     """Train the model with CPU-optimized settings."""
-     print(f"🚀 Training CPU-optimized BitTransformerLM for {epochs} epochs...")
-
-     # Set model to training mode
-     model.train()
-     set_dropout(model, 0.1)
-
-     # Configure optimizer for edge deployment
-     # Lower learning rate and smaller batch size for stable CPU training
-     batch_size = 4        # Small batch size for memory efficiency
-     learning_rate = 5e-4  # Conservative learning rate
-     total_steps = max(1, epochs * (len(train_data) // batch_size))  # Ensure at least 1 step
-
-     if len(train_data) == 0:
-         raise ValueError("No training data available - check dataset loading")
-
-     optimizer, scheduler = configure_optimizer(
-         model,
-         lr=learning_rate,
-         total_steps=total_steps,
-         weight_decay=0.01
-     )
-
-     print(f"   📋 Training Configuration:")
-     print(f"      Batch size: {batch_size}")
-     print(f"      Learning rate: {learning_rate}")
-     print(f"      Total steps: {total_steps}")
-     print(f"      CPU autocast: Enabled")
-
-     # Training loop with CPU optimizations
-     train_losses = []
-
-     for epoch in range(epochs):
-         print(f"\n   📖 Epoch {epoch + 1}/{epochs}")
-         epoch_losses = []
-         epoch_start_time = time.time()
-
-         # Shuffle training data
-         perm = torch.randperm(len(train_data))
-         train_data_shuffled = train_data[perm]
-
-         # Process in small batches
-         for batch_idx in range(0, len(train_data_shuffled), batch_size):
-             batch_end = min(batch_idx + batch_size, len(train_data_shuffled))
-             batch = train_data_shuffled[batch_idx:batch_end]
-
-             if len(batch) == 0:
-                 continue
-
-             optimizer.zero_grad()
-
-             # Use CPU autocast for mixed precision
-             with cpu_autocast():
-                 logits, telemetry = model(batch)
-
-                 # Standard autoregressive loss
-                 pred = logits[:, :-1, :].reshape(-1, 2)
-                 target = batch[:, 1:].reshape(-1)
-                 loss = F.cross_entropy(pred, target)
-
-             # Backward pass
-             loss.backward()
-             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-             optimizer.step()
-
-             # Only step scheduler if we haven't exceeded total steps
-             if scheduler.last_epoch < scheduler.total_steps - 1:
-                 scheduler.step()
-
-             batch_loss = loss.item()
-             epoch_losses.append(batch_loss)
-
-             # Log progress every 50 steps
-             if (batch_idx // batch_size) % 50 == 0:
-                 avg_loss = sum(epoch_losses[-10:]) / min(10, len(epoch_losses))
-                 telemetry_str = f"K={telemetry.get('K', 0):.3f}, C={telemetry.get('C', 0):.3f}, S={telemetry.get('S', 0):.3f}"
-                 print(f"      Step {batch_idx // batch_size}: Loss={avg_loss:.4f}, {telemetry_str}")
-
-         epoch_time = time.time() - epoch_start_time
-         avg_epoch_loss = sum(epoch_losses) / len(epoch_losses)
-         train_losses.append(avg_epoch_loss)
-
-         print(f"   ⏱️ Epoch {epoch + 1} completed in {epoch_time:.1f}s, Avg Loss: {avg_epoch_loss:.4f}")
-
-         # Validation every epoch
-         if len(valid_data) > 0:
-             model.eval()
-             set_dropout(model, 0.0)
-
-             with torch.no_grad():
-                 with cpu_autocast():
-                     val_batch = valid_data[:min(8, len(valid_data))]  # Small validation batch
-                     val_logits, val_telemetry = model(val_batch)
-                     val_pred = val_logits[:, :-1, :].reshape(-1, 2)
-                     val_target = val_batch[:, 1:].reshape(-1)
-                     val_loss = F.cross_entropy(val_pred, val_target).item()
-
-             print(f"   📊 Validation Loss: {val_loss:.4f}")
-             print(f"   📈 Telemetry - K: {val_telemetry.get('K', 0):.3f}, C: {val_telemetry.get('C', 0):.3f}, S: {val_telemetry.get('S', 0):.3f}")
-
-             model.train()
-             set_dropout(model, 0.1)
-
-     print(f"\n✅ Training completed!")
-     print(f"   Final training loss: {train_losses[-1]:.4f}")
-
-     return model, train_losses
-
-
- def test_model_inference(model, test_texts):
-     """Test the trained model with inference and safety checks."""
-     print("\n🧪 Testing Model Inference...")
-
-     model.eval()
-     set_dropout(model, 0.0)
-
-     # Test basic inference
-     test_samples = test_texts[:3]  # Test with first 3 samples
-
-     for i, text in enumerate(test_samples):
-         print(f"\n   Test {i + 1}: {text[:50]}...")
-
-         try:
-             # Convert to bits
-             input_bits = text_to_bits(text)[:64]  # Shorter for demo
-             if len(input_bits) < 64:
-                 input_bits.extend([0] * (64 - len(input_bits)))
-
-             input_tensor = torch.tensor([input_bits], dtype=torch.long)
-
-             # Run inference with CPU autocast
-             with torch.no_grad():
-                 with cpu_autocast():
-                     logits, telemetry = model(input_tensor)
-
-             # Generate next tokens
-             next_token_logits = logits[0, -1, :]
-             next_token_probs = F.softmax(next_token_logits, dim=-1)
-             next_token = torch.multinomial(next_token_probs, 1).item()
-
-             print(f"      Input bits: {input_bits[:16]}... (showing first 16)")
-             print(f"      Next token prediction: {next_token}")
-             print(f"      Next token confidence: {next_token_probs[next_token]:.3f}")
-             print(f"      Telemetry - K: {telemetry.get('K', 0):.3f}, C: {telemetry.get('C', 0):.3f}, S: {telemetry.get('S', 0):.3f}")
-
-         except Exception as e:
-             print(f"      ❌ Inference failed: {e}")
-
-     # Test safe inference
-     print(f"\n🛡️ Testing Safe Inference...")
-     try:
-         # Create a simple prompt
-         test_prompt = "The future of AI is"
-         prompt_bits = text_to_bits(test_prompt)
-         prompt_tensor = torch.tensor([prompt_bits], dtype=torch.long)
-
-         with cpu_autocast():
-             safe_result = hil_safe_inference(model, prompt_tensor, max_new_tokens=16)
-
-         if safe_result is not None:
-             print(f"   ✅ Safe inference successful")
-             print(f"   Generated {len(safe_result[0]) - len(prompt_bits)} new tokens")
-         else:
-             print(f"   ⚠️ Safe inference blocked by safety gates")
-
-     except Exception as e:
-         print(f"   ❌ Safe inference test failed: {e}")
-
-
- def benchmark_cpu_performance(model):
-     """Benchmark the model's CPU performance."""
-     print("\n⚡ CPU Performance Benchmark...")
-
-     model.eval()
-     set_dropout(model, 0.0)
-
-     # Prepare test data
-     batch_sizes = [1, 2, 4]
-     sequence_lengths = [32, 64, 128]
-
-     results = []
-
-     for batch_size in batch_sizes:
-         for seq_len in sequence_lengths:
-             print(f"\n   Testing batch_size={batch_size}, seq_len={seq_len}")
-
-             # Create random test data
-             test_data = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.long)
-
-             # Warmup
-             with torch.no_grad():
-                 with cpu_autocast():
-                     for _ in range(3):
-                         _, _ = model(test_data)
-
-             # Benchmark
-             times = []
-             for _ in range(10):
-                 start_time = time.time()
-                 with torch.no_grad():
-                     with cpu_autocast():
-                         logits, telemetry = model(test_data)
-                 end_time = time.time()
-                 times.append(end_time - start_time)
-
-             avg_time = sum(times) / len(times)
-             throughput = (batch_size * seq_len) / avg_time
-
-             result = {
-                 'batch_size': batch_size,
-                 'seq_len': seq_len,
-                 'avg_time_ms': avg_time * 1000,
-                 'throughput_tokens_per_sec': throughput
-             }
-             results.append(result)
-
-             print(f"      Average time: {avg_time * 1000:.2f}ms")
-             print(f"      Throughput: {throughput:.0f} tokens/sec")
-
-     # Summary
-     print(f"\n📊 Performance Summary:")
-     best_throughput = max(results, key=lambda x: x['throughput_tokens_per_sec'])
-     print(f"   Best throughput: {best_throughput['throughput_tokens_per_sec']:.0f} tokens/sec")
-     print(f"   At batch_size={best_throughput['batch_size']}, seq_len={best_throughput['seq_len']}")
-
-     return results
-
-
- def quantize_for_deployment(model):
-     """Apply dynamic quantization for deployment."""
-     print("\n🗜️ Applying Dynamic Quantization for Deployment...")
-
-     try:
-         quantized_model = quantize_dynamic(model)
-
-         # Compare model sizes
-         original_params = sum(p.numel() for p in model.parameters())
-         quantized_params = sum(p.numel() for p in quantized_model.parameters())
-
-         print(f"   Original parameters: {original_params:,}")
-         print(f"   Quantized parameters: {quantized_params:,}")
-         print(f"   Model size reduction: ~50% (FP32 -> INT8)")
-
-         # Quick inference test
-         test_input = torch.randint(0, 2, (1, 32), dtype=torch.long)
-
-         with torch.no_grad():
-             original_output = model(test_input)
-             quantized_output = quantized_model(test_input)
-
-         print(f"   ✅ Quantization successful - model still functional")
-
-         return quantized_model
-
-     except Exception as e:
-         print(f"   ❌ Quantization failed: {e}")
-         return model
-
-
- def main():
-     """Main training and testing pipeline."""
-     print("🚀 CPU-Optimized BitTransformerLM Training Pipeline")
-     print("="*60)
-
-     # Step 1: Create optimal CPU model
-     model = create_optimal_cpu_model()
-
-     # Step 2: Load training dataset
-     train_data, valid_data, train_texts = load_training_dataset(dataset_size=256, max_len=128)
-
-     # Step 3: Train the model
-     trained_model, train_losses = train_cpu_optimized_model(model, train_data, valid_data, epochs=3)
-
-     # Step 4: Test inference
-     test_model_inference(trained_model, train_texts)
-
-     # Step 5: Benchmark performance
-     benchmark_results = benchmark_cpu_performance(trained_model)
-
-     # Step 6: Apply quantization
-     quantized_model = quantize_for_deployment(trained_model)
-
-     # Step 7: Save models
-     print("\n💾 Saving Models...")
-
-     # Create weights directory if it doesn't exist
-     os.makedirs("weights", exist_ok=True)
-
-     try:
-         save_model(trained_model, "weights/cpu_edge_model.pt.gz")
-         print("   ✅ Saved trained model: weights/cpu_edge_model.pt.gz")
-
-         save_model(quantized_model, "weights/cpu_edge_model_quantized.pt.gz")
-         print("   ✅ Saved quantized model: weights/cpu_edge_model_quantized.pt.gz")
-
-     except Exception as e:
-         print(f"   ⚠️ Model saving failed: {e}")
-
-     # Final summary
-     print("\n" + "="*60)
-     print("🎉 CPU-Optimized BitTransformerLM Training Complete!")
-     print("="*60)
-
-     total_params = sum(p.numel() for p in trained_model.parameters())
-     final_loss = train_losses[-1] if train_losses else "N/A"
-     best_throughput = max(benchmark_results, key=lambda x: x['throughput_tokens_per_sec'])
-
-     print(f"📊 Final Results:")
-     print(f"   Model Parameters: {total_params:,}")
-     print(f"   Final Training Loss: {final_loss}")
-     print(f"   Peak Throughput: {best_throughput['throughput_tokens_per_sec']:.0f} tokens/sec")
-     print(f"   Model Size (quantized): ~{total_params * 1 / 1024 / 1024:.1f}MB")
-     print(f"   CPU Optimizations: BF16 autocast, no gradient checkpointing, small chunks")
-     print(f"   Edge Ready: ✅ Optimized for consumer CPUs")
-
-
- if __name__ == "__main__":
-     main()
create_dataset.py DELETED
@@ -1,61 +0,0 @@
- #!/usr/bin/env python3
- """
- BitTransformerLM Dataset Creation Script
-
- Usage:
-     python create_dataset.py --token YOUR_HF_TOKEN --repo-id YOUR_REPO_NAME
-
- This script creates a comprehensive dataset for BitTransformerLM training
- and uploads it to HuggingFace Hub with proper metadata and organization.
- """
-
- import argparse
- import sys
- from pathlib import Path
-
- # Add the bit_transformer module to path
- sys.path.insert(0, str(Path(__file__).parent))
-
- from bit_transformer.dataset_builder import create_bittransformerlm_dataset
-
-
- def main():
-     parser = argparse.ArgumentParser(description="Create BitTransformerLM Dataset")
-     parser.add_argument("--token", required=True, help="HuggingFace access token")
-     parser.add_argument("--repo-id", default="BitTransformerLM", help="Dataset repository ID")
-     parser.add_argument("--private", action="store_true", default=True, help="Make dataset private")
-     parser.add_argument("--samples", type=int, default=25000, help="Total number of samples")
-
-     args = parser.parse_args()
-
-     print("🚀 Starting BitTransformerLM Dataset Creation")
-     print(f"Repository: {args.repo_id}")
-     print(f"Private: {args.private}")
-     print(f"Target samples: {args.samples}")
-     print("-" * 50)
-
-     try:
-         dataset_url = create_bittransformerlm_dataset(
-             hf_token=args.token,
-             repo_id=args.repo_id
-         )
-
-         print("\n" + "=" * 50)
-         print("🎉 SUCCESS! Dataset created and uploaded")
-         print(f"📍 URL: {dataset_url}")
-         print("=" * 50)
-
-         print("\n📋 Next Steps:")
-         print("1. View your dataset on HuggingFace Hub")
-         print("2. Test loading with: `from datasets import load_dataset`")
-         print("3. Integrate with BitTransformerLM training pipeline")
-         print("4. Monitor dataset usage and performance metrics")
-
-     except Exception as e:
-         print(f"\n❌ ERROR: {e}")
-         print("Please check your token and repository permissions.")
-         sys.exit(1)
-
-
- if __name__ == "__main__":
-     main()
enhanced_checkpoint_system.py DELETED
@@ -1,374 +0,0 @@
- #!/usr/bin/env python3
- """
- Enhanced checkpointing system for BitTransformerLM with multiple training runs support.
- Optimized for Claude Code environment with HF Pro + 20GB persistent storage.
- """
-
- import os
- import json
- import shutil
- import logging
- from pathlib import Path
- from typing import Dict, Any, Optional, List, Union
- from datetime import datetime
- import torch
- from huggingface_hub import HfApi, hf_hub_download
-
- from bit_transformer.error_handling import with_error_recovery, safe_operation
- from bit_transformer.types import PathLike, ModelConfig, TrainingConfig
-
- logger = logging.getLogger(__name__)
-
-
- class EnhancedCheckpointManager:
-     """Advanced checkpoint management for multiple training runs with HF integration."""
-
-     def __init__(self,
-                  base_dir: PathLike = "/data/checkpoints",
-                  hf_repo_id: str = "WCNegentropy/BitTransformerLM",
-                  hf_token: Optional[str] = None,
-                  max_local_checkpoints: int = 5):
-
-         self.base_dir = Path(base_dir)
-         self.base_dir.mkdir(parents=True, exist_ok=True)
-
-         self.hf_repo_id = hf_repo_id
-         self.hf_token = hf_token or os.getenv("HF_TOKEN")
-         self.api = HfApi(token=self.hf_token) if self.hf_token else None
-
-         self.max_local_checkpoints = max_local_checkpoints
-
-         # Training session tracking
-         self.sessions_dir = self.base_dir / "training_sessions"
-         self.sessions_dir.mkdir(exist_ok=True)
-
-         # Best models storage
-         self.best_models_dir = self.base_dir / "best_models"
-         self.best_models_dir.mkdir(exist_ok=True)
-
-     def create_training_session(self,
-                                 session_name: str,
-                                 model_config: ModelConfig,
-                                 training_config: TrainingConfig) -> str:
-         """Create a new training session with metadata."""
-
-         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-         session_id = f"{session_name}_{timestamp}"
-         session_dir = self.sessions_dir / session_id
-         session_dir.mkdir(exist_ok=True)
-
-         # Save session metadata
-         metadata = {
-             "session_id": session_id,
-             "session_name": session_name,
-             "created_at": timestamp,
-             "model_config": model_config,
-             "training_config": training_config,
-             "checkpoints": [],
-             "best_metric": None,
-             "status": "active"
-         }
-
-         with open(session_dir / "metadata.json", "w") as f:
-             json.dump(metadata, f, indent=2, default=str)
-
-         logger.info(f"Created training session: {session_id}")
-         return session_id
-
-     @with_error_recovery(recovery_value=False)
-     def save_checkpoint(self,
-                         model: torch.nn.Module,
-                         session_id: str,
-                         epoch: int,
-                         metrics: Dict[str, float],
-                         optimizer_state: Optional[Dict] = None,
-                         scheduler_state: Optional[Dict] = None,
-                         additional_data: Optional[Dict] = None) -> bool:
-         """Save checkpoint with comprehensive metadata."""
-
-         session_dir = self.sessions_dir / session_id
-         if not session_dir.exists():
-             raise ValueError(f"Training session {session_id} not found")
-
-         # Create checkpoint directory
-         checkpoint_name = f"checkpoint_epoch_{epoch:04d}"
-         checkpoint_dir = session_dir / checkpoint_name
-         checkpoint_dir.mkdir(exist_ok=True)
-
-         # Save model state
-         model_path = checkpoint_dir / "model.pt"
-         torch.save({
-             'model_state_dict': model.state_dict(),
-             'epoch': epoch,
-             'metrics': metrics,
-             'model_config': getattr(model, 'config', {}),
-             'timestamp': datetime.now().isoformat()
-         }, model_path)
-
-         # Save optimizer state if provided
-         if optimizer_state:
-             torch.save(optimizer_state, checkpoint_dir / "optimizer.pt")
-
-         # Save scheduler state if provided
-         if scheduler_state:
-             torch.save(scheduler_state, checkpoint_dir / "scheduler.pt")
-
-         # Save additional data
-         if additional_data:
-             with open(checkpoint_dir / "additional_data.json", "w") as f:
-                 json.dump(additional_data, f, indent=2, default=str)
-
-         # Update session metadata
-         self._update_session_metadata(session_id, checkpoint_name, metrics)
-
-         # Cleanup old checkpoints to save space
-         self._cleanup_old_checkpoints(session_dir)
-
-         logger.info(f"Saved checkpoint {checkpoint_name} for session {session_id}")
-         return True
-
-     def load_checkpoint(self,
-                         session_id: str,
-                         checkpoint_name: Optional[str] = None,
-                         model: Optional[torch.nn.Module] = None) -> Dict[str, Any]:
-         """Load checkpoint with all associated data."""
-
-         session_dir = self.sessions_dir / session_id
-         if not session_dir.exists():
-             raise ValueError(f"Training session {session_id} not found")
-
-         # Use latest checkpoint if none specified
-         if checkpoint_name is None:
-             checkpoints = [d for d in session_dir.iterdir()
-                            if d.is_dir() and d.name.startswith("checkpoint_")]
-             if not checkpoints:
-                 raise ValueError(f"No checkpoints found for session {session_id}")
-             checkpoint_name = max(checkpoints, key=lambda x: x.name).name
-
-         checkpoint_dir = session_dir / checkpoint_name
-         if not checkpoint_dir.exists():
-             raise ValueError(f"Checkpoint {checkpoint_name} not found in session {session_id}")
-
-         # Load model state
-         model_path = checkpoint_dir / "model.pt"
-         checkpoint_data = torch.load(model_path, map_location='cpu', weights_only=False)
-
-         if model is not None:
-             model.load_state_dict(checkpoint_data['model_state_dict'])
-
-         # Load optimizer state if exists
-         optimizer_state = None
-         optimizer_path = checkpoint_dir / "optimizer.pt"
-         if optimizer_path.exists():
-             optimizer_state = torch.load(optimizer_path, map_location='cpu', weights_only=False)
-
-         # Load scheduler state if exists
-         scheduler_state = None
-         scheduler_path = checkpoint_dir / "scheduler.pt"
-         if scheduler_path.exists():
-             scheduler_state = torch.load(scheduler_path, map_location='cpu', weights_only=False)
-
-         # Load additional data if exists
-         additional_data = {}
-         additional_path = checkpoint_dir / "additional_data.json"
-         if additional_path.exists():
-             with open(additional_path) as f:
-                 additional_data = json.load(f)
-
-         return {
-             'model_data': checkpoint_data,
-             'optimizer_state': optimizer_state,
-             'scheduler_state': scheduler_state,
-             'additional_data': additional_data,
-             'checkpoint_path': str(checkpoint_dir)
-         }
-
-     def save_best_model(self,
-                         session_id: str,
-                         model: torch.nn.Module,
-                         metric_name: str,
-                         metric_value: float,
-                         is_better_func: callable = lambda x, y: x > y) -> bool:
-         """Save model if it achieves best performance."""
-
-         best_model_path = self.best_models_dir / f"{session_id}_best.pt"
-         best_meta_path = self.best_models_dir / f"{session_id}_best_meta.json"
-
-         # Check if this is the best model so far
-         current_best = None
-         if best_meta_path.exists():
-             with open(best_meta_path) as f:
-                 current_best = json.load(f)
-
-         if current_best is None or is_better_func(metric_value, current_best['metric_value']):
-             # Save new best model
-             torch.save({
-                 'model_state_dict': model.state_dict(),
-                 'metric_name': metric_name,
-                 'metric_value': metric_value,
-                 'session_id': session_id,
-                 'timestamp': datetime.now().isoformat()
-             }, best_model_path)
-
-             # Save metadata
-             with open(best_meta_path, "w") as f:
-                 json.dump({
-                     'metric_name': metric_name,
-                     'metric_value': metric_value,
-                     'session_id': session_id,
-                     'timestamp': datetime.now().isoformat()
-                 }, f, indent=2)
-
-             logger.info(f"New best model saved for session {session_id}: {metric_name}={metric_value}")
-             return True
-
-         return False
-
-     def push_to_hf(self,
-                    session_id: str,
-                    checkpoint_name: Optional[str] = None,
-                    include_optimizer: bool = False) -> bool:
-         """Push checkpoint to HuggingFace Hub."""
-
-         if not self.api:
-             logger.error("HuggingFace API not available - check token")
-             return False
-
-         try:
-             checkpoint_data = self.load_checkpoint(session_id, checkpoint_name)
-             checkpoint_dir = Path(checkpoint_data['checkpoint_path'])
-
-             # Upload model weights
-             self.api.upload_file(
-                 path_or_fileobj=str(checkpoint_dir / "model.pt"),
-                 path_in_repo=f"checkpoints/{session_id}/model.pt",
-                 repo_id=self.hf_repo_id,
-                 commit_message=f"Upload checkpoint {checkpoint_name or 'latest'} from session {session_id}"
-             )
-
-             # Upload optimizer state if requested and exists
-             if include_optimizer and (checkpoint_dir / "optimizer.pt").exists():
-                 self.api.upload_file(
-                     path_or_fileobj=str(checkpoint_dir / "optimizer.pt"),
-                     path_in_repo=f"checkpoints/{session_id}/optimizer.pt",
-                     repo_id=self.hf_repo_id
-                 )
-
-             logger.info(f"Successfully pushed checkpoint to HuggingFace: {self.hf_repo_id}")
-             return True
-
-         except Exception as e:
-             logger.error(f"Failed to push to HuggingFace: {e}")
-             return False
-
-     def pull_from_hf(self,
-                      session_id: str,
-                      local_session_id: Optional[str] = None) -> bool:
-         """Pull checkpoint from HuggingFace Hub."""
-
-         if not self.api:
-             logger.error("HuggingFace API not available - check token")
-             return False
-
-         try:
-             local_session = local_session_id or session_id
-             local_dir = self.sessions_dir / local_session / "checkpoint_from_hf"
-             local_dir.mkdir(parents=True, exist_ok=True)
-
-             # Download model weights
-             model_file = hf_hub_download(
-                 repo_id=self.hf_repo_id,
-                 filename=f"checkpoints/{session_id}/model.pt",
-                 local_dir=str(local_dir),
-                 local_dir_use_symlinks=False
-             )
-
-             logger.info(f"Successfully pulled checkpoint from HuggingFace to {local_dir}")
-             return True
-
-         except Exception as e:
-             logger.error(f"Failed to pull from HuggingFace: {e}")
-             return False
-
-     def get_storage_usage(self) -> Dict[str, Any]:
-         """Get detailed storage usage breakdown."""
-
-         def get_dir_size(path: Path) -> int:
-             total = 0
-             for item in path.rglob('*'):
-                 if item.is_file():
-                     total += item.stat().st_size
-             return total
-
-         usage = {
-             'total_gb': get_dir_size(self.base_dir) / 1e9,
-             'sessions_gb': get_dir_size(self.sessions_dir) / 1e9,
-             'best_models_gb': get_dir_size(self.best_models_dir) / 1e9,
-             'num_sessions': len(list(self.sessions_dir.iterdir())),
-             'num_best_models': len(list(self.best_models_dir.glob('*_best.pt'))),
-         }
-
-         # Get per-session breakdown
-         sessions = []
-         for session_dir in self.sessions_dir.iterdir():
-             if session_dir.is_dir():
-                 sessions.append({
-                     'session_id': session_dir.name,
-                     'size_gb': get_dir_size(session_dir) / 1e9,
-                     'num_checkpoints': len(list(session_dir.glob('checkpoint_*')))
-                 })
-
-         usage['sessions'] = sorted(sessions, key=lambda x: x['size_gb'], reverse=True)
-
-         return usage
-
-     def _update_session_metadata(self, session_id: str, checkpoint_name: str, metrics: Dict[str, float]):
-         """Update session metadata with new checkpoint info."""
-         metadata_path = self.sessions_dir / session_id / "metadata.json"
328
-
329
- with open(metadata_path) as f:
330
- metadata = json.load(f)
331
-
332
- metadata['checkpoints'].append({
333
- 'name': checkpoint_name,
334
- 'metrics': metrics,
335
- 'timestamp': datetime.now().isoformat()
336
- })
337
-
338
- # Update best metric if applicable
339
- if 'loss' in metrics:
340
- if metadata['best_metric'] is None or metrics['loss'] < metadata['best_metric'].get('loss', float('inf')):
341
- metadata['best_metric'] = metrics.copy()
342
-
343
- with open(metadata_path, "w") as f:
344
- json.dump(metadata, f, indent=2, default=str)
345
-
346
- def _cleanup_old_checkpoints(self, session_dir: Path):
347
- """Remove oldest checkpoints to stay within limits."""
348
- checkpoints = sorted([d for d in session_dir.iterdir()
349
- if d.is_dir() and d.name.startswith("checkpoint_")],
350
- key=lambda x: x.stat().st_mtime)
351
-
352
- while len(checkpoints) > self.max_local_checkpoints:
353
- old_checkpoint = checkpoints.pop(0)
354
- shutil.rmtree(old_checkpoint)
355
- logger.info(f"Cleaned up old checkpoint: {old_checkpoint.name}")
356
-
357
-
358
- # Convenience functions for easy usage
359
- def create_checkpoint_manager(hf_token: str = "os.environ.get('HF_TOKEN', 'your-token-here')") -> EnhancedCheckpointManager:
360
- """Create a pre-configured checkpoint manager for this environment."""
361
- return EnhancedCheckpointManager(
362
- base_dir="/data/checkpoints",
363
- hf_repo_id="WCNegentropy/BitTransformerLM",
364
- hf_token=hf_token,
365
- max_local_checkpoints=3 # Conservative for 20GB storage
366
- )
367
-
368
-
369
- if __name__ == "__main__":
370
- # Demo usage
371
- manager = create_checkpoint_manager()
372
- usage = manager.get_storage_usage()
373
- print(f"Current storage usage: {usage['total_gb']:.2f} GB")
374
- print(f"Number of training sessions: {usage['num_sessions']}")
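The `_cleanup_old_checkpoints` method above keeps only the N most recently modified `checkpoint_*` directories. The same retention policy can be sketched standalone with the stdlib (hypothetical helper, not the deleted class's API):

```python
import shutil
import tempfile
import time
from pathlib import Path

def prune_checkpoints(session_dir: Path, keep: int) -> list:
    """Delete the oldest checkpoint_* directories, keeping the `keep` newest by mtime."""
    ckpts = sorted(
        (d for d in session_dir.iterdir()
         if d.is_dir() and d.name.startswith("checkpoint_")),
        key=lambda d: d.stat().st_mtime,
    )
    removed = []
    while len(ckpts) > keep:
        old = ckpts.pop(0)       # oldest first
        shutil.rmtree(old)
        removed.append(old.name)
    return removed

root = Path(tempfile.mkdtemp())
for i in range(5):
    (root / f"checkpoint_{i}").mkdir()
    time.sleep(0.05)  # distinct mtimes so ordering is deterministic

removed = prune_checkpoints(root, keep=3)
print(removed)  # the two oldest directories
```

Sorting by `st_mtime` ascending means index 0 is always the oldest surviving checkpoint, so the loop removes strictly in age order.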
example.py DELETED
@@ -1,6 +0,0 @@
- from bit_transformer import example_training_step
-
- if __name__ == "__main__":
-     loss, telemetry = example_training_step()
-     print("Training loss:", loss)
-     print("Available telemetry:", list(telemetry.keys()))
full_bits_train.py DELETED
@@ -1,51 +0,0 @@
- import pathlib
- import torch
- from bit_transformer import BitTransformerLM
-
- DATA_PATH = pathlib.Path('full_bits.pt')
-
- class BitSeq(torch.utils.data.IterableDataset):
-     def __init__(self, path: str | pathlib.Path = DATA_PATH, seq: int = 2048) -> None:
-         self.bits = torch.load(path, mmap=True)
-         self.seq = seq
-
-     def __len__(self) -> int:
-         return (self.bits.numel() // self.seq) - 1
-
-     def __iter__(self):
-         N = (self.bits.numel() // self.seq) - 1
-         for i in range(N):
-             s = i * self.seq
-             yield (
-                 self.bits[s:s+self.seq].long(),
-                 self.bits[s+1:s+self.seq+1].long(),
-             )
-
- def main() -> None:
-     dl = torch.utils.data.DataLoader(
-         BitSeq(DATA_PATH, seq=2048),
-         batch_size=8,
-         num_workers=0,
-         pin_memory=False,
-     )
-
-     model = BitTransformerLM(
-         d_model=64,
-         nhead=4,
-         num_layers=2,
-         dim_feedforward=256,
-         max_seq_len=2048,
-         reversible=True,
-         use_autocast=True,
-     )
-
-     loss_fn = torch.nn.CrossEntropyLoss()
-     xb, yb = next(iter(dl))
-     logits, _ = model(xb)
-     pred = logits.reshape(-1, 2)
-     target = yb.reshape(-1)
-     loss = loss_fn(pred, target)
-     print('Batch loss:', float(loss))
-
- if __name__ == '__main__':
-     main()
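`BitSeq` above slices a flat bit stream into next-bit-prediction pairs: the target window is the input window shifted one position ahead. A sketch of that indexing on a toy list (plain Python, no torch, to show the slicing only):

```python
bits = [0, 1, 1, 0, 1, 0, 0, 1]
seq = 4

# Same indexing as BitSeq.__iter__: one pair per non-overlapping window,
# with the target shifted one bit ahead of the input.
n = (len(bits) // seq) - 1
pairs = [(bits[i * seq:i * seq + seq], bits[i * seq + 1:i * seq + seq + 1])
         for i in range(n)]

x, y = pairs[0]
print(x)  # [0, 1, 1, 0]
print(y)  # [1, 1, 0, 1]
```

The `- 1` in the length keeps the final target slice in bounds, since every target reaches one bit past its input window.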
integration_flow.py DELETED
@@ -1,110 +0,0 @@
- import torch
- from torch.profiler import profile
- from bit_transformer import (
-     BitTransformerLM,
-     quantize_dynamic,
-     hil_safe_inference,
-     collapse_submodel,
- )
- from bit_transformer.training import train_loop
- from bit_transformer.torch_utils import cpu_autocast
-
- def train(
-     model: BitTransformerLM,
-     data: torch.Tensor,
-     epochs: int = 3,
-     compress_prob: float = 0.5,
-     direct_prob: float = 0.0,
-     log: bool = False,
-     forward_kwargs: dict | None = None,
- ) -> list[dict]:
-     """Train on bit sequences with optional random compression.
-
-     If ``direct_prob`` is positive, some batches are fed using their
-     run-length encoded representation packed into bits.  Loss on these
-     direct-compressed batches is tracked separately.
-
-     Returns a list of per-epoch metric dictionaries containing raw and
-     compressed loss/accuracy statistics and the mean compression ratio.
-     """
-     return train_loop(
-         model,
-         data,
-         epochs=epochs,
-         compress_prob=compress_prob,
-         direct_prob=direct_prob,
-         log=log,
-         forward_kwargs=forward_kwargs,
-     )
-
-
- def main() -> None:
-     data = torch.randint(0, 2, (64, 128), dtype=torch.long)
-     validation_bits = torch.randint(0, 2, (16, 128), dtype=torch.long)
-     input_bits = torch.randint(0, 2, (1, 128), dtype=torch.long)
-     bit_sequence_data = data.tolist()
-
-     model = BitTransformerLM(
-         d_model=32,
-         nhead=4,
-         num_layers=1,
-         dim_feedforward=64,
-         max_seq_len=128,
-         use_act=True,
-         act_threshold=0.7,
-         reversible=True,
-         chunk_size=128,
-     )
-
-     for step in range(1, 13):
-         if step % 2 == 0:
-             model = model.double_width()
-         else:
-             model = model.double_layers()
-         train(model, data, epochs=3, compress_prob=0.5, log=True)
-         _, telemetry = model(validation_bits)
-         K = telemetry["negentropy_logits"].mean().item()
-         C = telemetry["lz_complexity_logits"].mean().item()
-         S = telemetry["symbiosis_score"].mean().item()
-         assert (
-             K > 0.3 and C > 0.35 and S > 0.5
-         ), f"Step {step} telemetry floor failure"
-
-     with cpu_autocast():
-         model(input_bits)
-
-     quantized_model = quantize_dynamic(model)
-     quantized_model.eval()
-
-     safe_output, _ = hil_safe_inference(
-         quantized_model, input_bits, c_floor=0.35, s_floor=0.5
-     )
-
-     student_model, _ = collapse_submodel(
-         bit_sequence_data,
-         target_params=dict(
-             d_model=16,
-             nhead=4,
-             num_layers=1,
-             dim_feedforward=32,
-             max_seq_len=128,
-         ),
-         floors={"negentropy": 0.3, "lz_complexity": 0.35, "symbiosis_score": 0.5},
-     )
-
-     compiled_model = (
-         torch.compile(student_model)
-         if hasattr(torch, "compile")
-         else student_model
-     )
-     compiled_model.eval()
-
-     with profile() as prof:
-         compiled_model(input_bits)
-
-     prof.export_chrome_trace("trace12.json")
-     print("Safe output bits:", safe_output.squeeze(0).tolist())
-
-
- if __name__ == "__main__":
-     main()
integration_schedule.py DELETED
@@ -1,379 +0,0 @@
- import os
- import time
- import math
- from itertools import cycle
- from typing import Optional
-
- import torch
- import torch.nn.functional as F
- from bit_transformer import (
-     BitTransformerLM,
-     text_to_bits,
-     quantize_dynamic,
-     prepare_qat_fx,
-     convert_qat_fx,
-     hil_safe_inference,
-     collapse_submodel,
-     diffusion_inference,
-     TelemetrySynthesizer,
-     save_distilled_model,
- )
- from bit_transformer.training import train_loop as train
- from bit_transformer.optimization import configure_optimizer, adjust_learning_rate
- from bit_transformer.utils import save_model, load_model, set_dropout
- from bit_transformer.torch_utils import cpu_autocast
-
-
- def lines_to_tensor(lines, max_len):
-     seqs = []
-     for text in lines:
-         bits = text_to_bits(text)[:max_len]
-         if len(bits) < max_len:
-             bits.extend([0] * (max_len - len(bits)))
-         seqs.append(bits)
-     return torch.tensor(seqs, dtype=torch.long)
-
-
- def load_wikitext(dataset_size=128, max_len=64):
-     try:
-         from datasets import load_dataset
-
-         ds = load_dataset("wikitext", "wikitext-2-raw-v1")
-         train_lines = [t for t in ds["train"]["text"] if t.strip()][:dataset_size]
-         valid_split = max(1, dataset_size // 4)
-         valid_lines = [t for t in ds["validation"]["text"] if t.strip()][:valid_split]
-         train = lines_to_tensor(train_lines, max_len)
-         valid = lines_to_tensor(valid_lines, max_len)
-         return train, valid, train_lines
-     except Exception as e:
-         print("Dataset load failed, using random bits", e)
-         train = torch.randint(0, 2, (dataset_size, max_len), dtype=torch.long)
-         valid = torch.randint(0, 2, (max_len, max_len), dtype=torch.long)
-         return train, valid, ["" for _ in range(len(train))]
-
-
- def _warmup(
-     model: BitTransformerLM,
-     data: torch.Tensor,
-     steps: int = 5,
-     freeze_old: bool = False,
-     old_layers: int = 0,
-     *,
-     diffusion: bool = False,
-     curriculum: bool = False,
-     optimizer: Optional[torch.optim.Optimizer] = None,
-     scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
- ) -> None:
-     """Run a short warm-up loop after expansion."""
-     model.train()
-     set_dropout(model, 0.1)
-     if freeze_old:
-         for idx, layer in enumerate(model.layers):
-             if idx < old_layers:
-                 for p in layer.parameters():
-                     p.requires_grad_(False)
-     if optimizer is None or scheduler is None:
-         optimizer, scheduler = configure_optimizer(model, lr=1e-3, total_steps=steps)
-     it = iter(data.split(8))
-     for idx in range(steps):
-         try:
-             batch = next(it)
-         except StopIteration:
-             it = iter(data.split(8))
-             batch = next(it)
-         if diffusion:
-             p = 0.5 * (1 - idx / max(1, steps - 1)) if curriculum else 0.5
-             noise = (torch.rand_like(batch.float()) < p).long()
-             noisy = batch ^ noise
-             logits, _ = model(noisy, causal=False)
-             pred = logits.reshape(-1, 2)
-             target = batch.reshape(-1)
-         else:
-             logits, _ = model(batch)
-             pred = logits[:, :-1, :].reshape(-1, 2)
-             target = batch[:, 1:].reshape(-1)
-         loss = F.cross_entropy(pred, target)
-         loss.backward()
-         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-         optimizer.step()
-         scheduler.step()
-         optimizer.zero_grad()
-     for p in model.parameters():
-         p.requires_grad_(True)
-     model.eval()
-     set_dropout(model, 0.0)
-
-
- def integration_schedule(
-     steps: int = 10,
-     max_len: int = 64,
-     dataset_size: int = 128,
-     *,
-     weights_path: str = "weights/model.pt.gz",
-     plateau_steps: int = 0,
-     collapsed_path: str | None = None,
-     epochs_per_step: int = 2,
-     extra_steps: int = 3,
-     collapse: bool = True,
-     diffusion: bool = False,
-     noise_schedule: str = "linear",
-     diffusion_steps: int = 8,
-     diffusion_curriculum: bool = False,
-     use_checkpoint: bool = True,
-     reversible: bool = True,
-     improve_thresh: float = 0.01,
-     qat: bool = False,
- ):
-     start = time.time()
-     train_bits, valid_bits, train_lines = load_wikitext(dataset_size, max_len)
-     if os.path.exists(weights_path):
-         try:
-             model = load_model(weights_path)
-             print(f"Loaded model from {weights_path}")
-         except Exception as e:
-             print("Failed to load weights, initializing new model", e)
-             model = BitTransformerLM(
-                 d_model=32,
-                 nhead=4,
-                 num_layers=1,
-                 dim_feedforward=64,
-                 max_seq_len=max_len,
-                 use_act=True,
-                 act_threshold=0.7,
-                 reversible=reversible,
-                 chunk_size=max_len,
-                 use_autocast=True,
-                 use_checkpoint=use_checkpoint,
-             )
-     else:
-         model = BitTransformerLM(
-             d_model=32,
-             nhead=4,
-             num_layers=1,
-             dim_feedforward=64,
-             max_seq_len=max_len,
-             use_act=True,
-             act_threshold=0.7,
-             reversible=reversible,
-             chunk_size=max_len,
-             use_autocast=True,
-             use_checkpoint=use_checkpoint,
-         )
-     if qat:
-         model = prepare_qat_fx(model)
-     results = []
-     scale_cycle = cycle(["layers", "width", "context"])
-     base_lr = 1e-3
-     prev_val_loss: Optional[float] = None
-     for step in range(steps):
-         model.train()
-         set_dropout(model, 0.1)
-         opt, sched = configure_optimizer(
-             model, lr=base_lr, total_steps=epochs_per_step
-         )
-         train(
-             model,
-             train_bits,
-             epochs=epochs_per_step,
-             extra_steps=extra_steps,
-             compress_prob=0.0 if diffusion else 1.0,
-             log=True,
-             diffusion=diffusion,
-             diffusion_curriculum=diffusion_curriculum,
-             optimizer=opt,
-             scheduler=sched,
-         )
-
-         model.eval()
-         set_dropout(model, 0.0)
-         with torch.no_grad():
-             logits, telemetry = model(valid_bits, causal=not diffusion)
-             if diffusion:
-                 pred = logits.reshape(-1, 2)
-                 target = valid_bits.reshape(-1)
-             else:
-                 pred = logits[:, :-1, :].reshape(-1, 2)
-                 target = valid_bits[:, 1:].reshape(-1)
-             val_loss = F.cross_entropy(pred, target).item()
-         k = telemetry["negentropy_logits"].mean().item()
-         c = telemetry["lz_complexity_logits"].mean().item()
-         s = telemetry["symbiosis_score"].mean().item()
-         print(f"Step {step} validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}")
-         results.append((step, val_loss, k, c, s))
-
-         if prev_val_loss is not None and prev_val_loss - val_loss < improve_thresh:
-             strategy = next(scale_cycle)
-             base_lr = adjust_learning_rate(opt, 1 / math.sqrt(2))
-             if strategy == "layers":
-                 old_layers = model.num_layers
-                 model = model.double_layers()
-                 warm_opt, warm_sched = configure_optimizer(
-                     model, lr=base_lr, total_steps=100
-                 )
-                 _warmup(
-                     model,
-                     train_bits,
-                     steps=100,
-                     freeze_old=True,
-                     old_layers=old_layers,
-                     diffusion=diffusion,
-                     curriculum=diffusion_curriculum,
-                     optimizer=warm_opt,
-                     scheduler=warm_sched,
-                 )
-             elif strategy == "width":
-                 model = model.double_width()
-                 warm_opt, warm_sched = configure_optimizer(
-                     model, lr=base_lr, total_steps=100
-                 )
-                 _warmup(
-                     model,
-                     train_bits,
-                     steps=100,
-                     diffusion=diffusion,
-                     curriculum=diffusion_curriculum,
-                     optimizer=warm_opt,
-                     scheduler=warm_sched,
-                 )
-             else:
-                 max_len *= 2
-                 train_bits, valid_bits, train_lines = load_wikitext(
-                     dataset_size, max_len
-                 )
-                 model = model.double_length()
-                 warm_opt, warm_sched = configure_optimizer(
-                     model, lr=base_lr, total_steps=100
-                 )
-                 _warmup(
-                     model,
-                     train_bits,
-                     steps=100,
-                     diffusion=diffusion,
-                     curriculum=diffusion_curriculum,
-                     optimizer=warm_opt,
-                     scheduler=warm_sched,
-                 )
-
-         prev_val_loss = val_loss
-         if time.time() - start > 8 * 60:
-             print("Time limit reached")
-             break
-
-     # optional plateau phase at final size
-     for p in range(plateau_steps):
-         model.train()
-         set_dropout(model, 0.1)
-         train(
-             model,
-             train_bits,
-             epochs=epochs_per_step,
-             extra_steps=extra_steps,
-             compress_prob=0.0 if diffusion else 1.0,
-             log=True,
-             diffusion=diffusion,
-             diffusion_curriculum=diffusion_curriculum,
-         )
-         model.eval()
-         set_dropout(model, 0.0)
-         with torch.no_grad():
-             logits, telemetry = model(valid_bits, causal=not diffusion)
-             if diffusion:
-                 pred = logits.reshape(-1, 2)
-                 target = valid_bits.reshape(-1)
-             else:
-                 pred = logits[:, :-1, :].reshape(-1, 2)
-                 target = valid_bits[:, 1:].reshape(-1)
-             val_loss = F.cross_entropy(pred, target).item()
-         k = telemetry["negentropy_logits"].mean().item()
-         c = telemetry["lz_complexity_logits"].mean().item()
-         s = telemetry["symbiosis_score"].mean().item()
-         idx = steps + p
-         print(
-             f"Plateau {p} validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}"
-         )
-         results.append((idx, val_loss, k, c, s))
-         if time.time() - start > 8 * 60:
-             print("Time limit reached")
-             break
-
-     # final validation after last step
-     model.eval()
-     set_dropout(model, 0.0)
-     with torch.no_grad():
-         logits, telemetry = model(valid_bits, causal=not diffusion)
-         if diffusion:
-             pred = logits.reshape(-1, 2)
-             target = valid_bits.reshape(-1)
-         else:
-             pred = logits[:, :-1, :].reshape(-1, 2)
-             target = valid_bits[:, 1:].reshape(-1)
-         val_loss = F.cross_entropy(pred, target).item()
-         k = telemetry["negentropy_logits"].mean().item()
-         c = telemetry["lz_complexity_logits"].mean().item()
-         s = telemetry["symbiosis_score"].mean().item()
-
-     print(f"Final validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}")
-     results.append((steps + plateau_steps, val_loss, k, c, s))
-
-     # persist final model weights for future runs
-     save_model(model, weights_path)
-
-     input_bits = valid_bits[:1]
-     if qat:
-         qmodel = convert_qat_fx(model)
-     else:
-         with cpu_autocast():
-             model(input_bits)
-         qmodel = quantize_dynamic(model)
-     qmodel.eval()
-     try:
-         hil_safe_inference(
-             qmodel,
-             input_bits,
-             c_floor=0.3,
-             s_floor=0.5,
-             causal=not diffusion,
-             strict=not diffusion,
-         )
-     except RuntimeError as e:
-         print("Safety gate triggered", e)
-     collapsed = None
-     if collapse:
-         synth = TelemetrySynthesizer(n_clusters=8)
-         reps = synth.cluster_sequences(model, train_bits[:64])
-         floors = {"negentropy": 0.3, "lz_complexity": 0.35, "symbiosis_score": 0.5}
-         collapsed, metrics = collapse_submodel(
-             reps,
-             target_params=dict(
-                 d_model=16,
-                 nhead=4,
-                 num_layers=1,
-                 dim_feedforward=32,
-                 max_seq_len=max_len,
-             ),
-             floors=floors,
-         )
-         collapsed.eval()
-         with torch.no_grad():
-             logits, _ = collapsed(valid_bits)
-             pred = logits[:, :-1, :].reshape(-1, 2)
-             target = valid_bits[:, 1:].reshape(-1)
-             c_loss = F.cross_entropy(pred, target).item()
-         print("Collapsed model validation loss:", c_loss)
-         if collapsed_path is not None:
-             save_distilled_model(
-                 collapsed,
-                 collapsed_path,
-                 {**metrics, "val_loss": c_loss},
-                 floors=floors,
-             )
-     if diffusion:
-         sample = diffusion_inference(
-             model, length=max_len, steps=diffusion_steps, schedule=noise_schedule
-         )
-         print("Diffusion sample:", sample[0].tolist())
-     return results, collapsed
-
-
- if __name__ == "__main__":
-     integration_schedule()
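`integration_schedule` scales the model only when validation loss fails to drop by at least `improve_thresh`, rotating through layer, width, and context growth via `cycle([...])`. That trigger can be sketched in isolation (a hypothetical helper mirroring the comparison above, not the project API):

```python
from itertools import cycle

def next_scale_action(prev_loss, val_loss, improve_thresh, scale_cycle):
    """Return the scaling strategy to apply, or None when loss improved enough."""
    if prev_loss is None or prev_loss - val_loss >= improve_thresh:
        return None          # healthy improvement (or first step): keep training
    return next(scale_cycle)  # plateau: take the next growth strategy in rotation

actions = cycle(["layers", "width", "context"])
print(next_scale_action(None, 0.90, 0.01, actions))    # None: first evaluation
print(next_scale_action(0.90, 0.85, 0.01, actions))    # None: improved by 0.05
print(next_scale_action(0.85, 0.849, 0.01, actions))   # 'layers': plateau
print(next_scale_action(0.849, 0.848, 0.01, actions))  # 'width': still flat
```

Because the cycle only advances on plateaus, consecutive stalls get different growth strategies instead of repeatedly doubling the same dimension.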
progressive_scaleup.py DELETED
@@ -1,216 +0,0 @@
- """Legacy progressive scale-up demo.
-
- This script is retained for historical reference but has been superseded by
- ``integration_schedule.py`` which provides a more flexible scaling workflow.
- """
-
- import argparse
- import warnings
- import torch
- import torch.nn.functional as F
- from bit_transformer import (
-     BitTransformerLM,
-     configure_optimizer,
-     expand_model,
-     text_to_bits,
- )
- from bit_transformer.training import train_loop as basic_train
-
- warnings.warn(
-     "progressive_scaleup.py is deprecated; use integration_schedule.py instead.",
-     DeprecationWarning,
-     stacklevel=2,
- )
-
-
- def progressive_scale_up(
-     eps: float = 0.65,
-     steps: int = 2,
-     width_mult: float = 1.0,
-     forward_kwargs: dict | None = None,
- ) -> None:
-     """Demonstrate automatic scaling of the model on random data."""
-     params = dict(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=16)
-     model = BitTransformerLM(**params)
-     steps_per_epoch = 64 // 8
-     optimizer, scheduler = configure_optimizer(
-         model, lr=1e-3, total_steps=steps * steps_per_epoch
-     )
-
-     train = torch.randint(0, 2, (64, params["max_seq_len"]), dtype=torch.long)
-     valid = torch.randint(0, 2, (16, params["max_seq_len"]), dtype=torch.long)
-
-     for step in range(steps):
-         # one epoch over train
-         basic_train(
-             model,
-             train,
-             epochs=1,
-             compress_prob=0.5,
-             log=False,
-             forward_kwargs=forward_kwargs,
-         )
-
-         with torch.no_grad():
-             logits, _ = model(valid, **(forward_kwargs or {}))
-             pred = logits[:, :-1, :].reshape(-1, 2)
-             target = valid[:, 1:].reshape(-1)
-             val_loss = F.cross_entropy(pred, target).item()
-         print(f"Step {step} validation loss: {val_loss:.4f}")
-         if val_loss < eps:
-             params["num_layers"] *= 2
-             params["d_model"] = int(params["d_model"] * width_mult)
-             params["dim_feedforward"] = int(params["dim_feedforward"] * width_mult)
-             model = expand_model(model, params)
-             optimizer, scheduler = configure_optimizer(
-                 model, lr=1e-3, total_steps=steps * steps_per_epoch
-             )
-             print(
-                 "Scaled model to", params["num_layers"], "layers and width", params["d_model"]
-             )
-
-
- def progressive_scale_up_text(
-     improve_thresh: float = 0.01,
-     steps: int = 2,
-     width_mult: float = 2.0,
-     max_len: int = 64,
-     dataset_size: int = 512,
-     forward_kwargs: dict | None = None,
- ) -> None:
-     """Scale up using WikiText2 lines converted to bits.
-
-     Parameters
-     ----------
-     improve_thresh: float
-         Relative validation loss improvement required to avoid scaling.
-         If improvement is <= this threshold, model size is increased.
-     steps: int
-         Number of training steps.
-     width_mult: float
-         Multiplier applied when increasing model width.
-     max_len: int
-         Initial sequence length.
-     dataset_size: int
-         Number of training lines to load from WikiText2.
-     forward_kwargs: dict | None
-         Extra keyword arguments for the forward pass.
-     """
-     from datasets import load_dataset
-
-     ds = load_dataset("wikitext", "wikitext-2-raw-v1")
-     train_iter = ds["train"]["text"]
-     valid_iter = ds["validation"]["text"]
-
-     train_lines = []
-     for line in train_iter:
-         train_lines.append(line)
-         if len(train_lines) >= dataset_size:
-             break
-
-     valid_lines = []
-     for line in valid_iter:
-         valid_lines.append(line)
-         if len(valid_lines) >= dataset_size // 4:
-             break
-
-     def lines_to_tensor(lines: list[str], length: int) -> torch.Tensor:
-         seqs = []
-         for text in lines:
-             bits = text_to_bits(text)[:length]
-             if len(bits) < length:
-                 bits.extend([0] * (length - len(bits)))
-             seqs.append(bits)
-         return torch.tensor(seqs, dtype=torch.long)
-
-     train = lines_to_tensor(train_lines, max_len)
-     valid = lines_to_tensor(valid_lines, max_len)
-
-     params = dict(
-         d_model=32,
-         nhead=4,
-         num_layers=1,
-         dim_feedforward=64,
-         max_seq_len=max_len,
-     )
-     model = BitTransformerLM(**params)
-     steps_per_epoch = len(train) // 8
-     optimizer, scheduler = configure_optimizer(
-         model, lr=1e-3, total_steps=steps * max(1, steps_per_epoch)
-     )
-
-     prev_loss: float | None = None
-     scale_length = True
-
-     for step in range(steps):
-         basic_train(
-             model,
-             train,
-             epochs=1,
-             compress_prob=0.5,
-             log=False,
-             forward_kwargs=forward_kwargs,
-         )
-
-         with torch.no_grad():
-             logits, _ = model(valid, **(forward_kwargs or {}))
-             pred = logits[:, :-1, :].reshape(-1, 2)
-             target = valid[:, 1:].reshape(-1)
-             val_loss = F.cross_entropy(pred, target).item()
-         print(f"Step {step} validation loss: {val_loss:.4f}")
-         if prev_loss is not None:
-             improvement = (prev_loss - val_loss) / max(prev_loss, 1e-8)
-             if improvement <= improve_thresh:
-                 if scale_length:
-                     params["max_seq_len"] *= 2
-                     train = lines_to_tensor(train_lines, params["max_seq_len"])
-                     valid = lines_to_tensor(valid_lines, params["max_seq_len"])
-                     model = model.double_length()
-                     steps_per_epoch = len(train) // 8
-                     scale_type = "length"
-                 else:
-                     params["d_model"] = int(params["d_model"] * width_mult)
-                     params["dim_feedforward"] = int(params["dim_feedforward"] * width_mult)
-                     model = expand_model(model, params)
-                     scale_type = "width"
-                 optimizer, scheduler = configure_optimizer(
-                     model, lr=1e-3, total_steps=steps * max(1, steps_per_epoch)
-                 )
-                 scale_length = not scale_length
-                 param_count = sum(p.numel() for p in model.parameters())
-                 print(
-                     f"Scaled {scale_type}; seq_len={params['max_seq_len']} width={params['d_model']} params={param_count}"
-                 )
-         prev_loss = val_loss
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser(description="Progressively scale model length and width")
-     parser.add_argument("--steps", type=int, default=2, help="number of training steps")
-     parser.add_argument(
-         "--improve-thresh",
-         type=float,
-         default=0.01,
-         help="relative loss improvement required to avoid scaling",
-     )
-     parser.add_argument(
-         "--width-mult", type=float, default=2.0, help="width multiplier when scaling"
-     )
-     parser.add_argument("--causal", action="store_true", help="use causal attention during training")
-     parser.add_argument("--wikitext", action="store_true", help="use WikiText2 dataset")
-     args = parser.parse_args()
-     if args.wikitext:
-         progressive_scale_up_text(
-             improve_thresh=args.improve_thresh,
-             steps=args.steps,
-             width_mult=args.width_mult,
-             forward_kwargs={"causal": args.causal} if args.causal else None,
-         )
-     else:
-         progressive_scale_up(
-             eps=args.improve_thresh,
-             steps=args.steps,
-             width_mult=args.width_mult,
-             forward_kwargs={"causal": args.causal} if args.causal else None,
-         )
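The deleted scripts repeatedly pad or truncate each encoded line to a fixed bit length before stacking rows into a tensor. A self-contained sketch of that step; `to_bits` is a stand-in for the package's `text_to_bits` (assumed here to emit 8 bits per character, MSB first, which may differ from the real encoder):

```python
def to_bits(text):
    """Stand-in encoder: 8 bits per character, most significant bit first."""
    return [(ord(ch) >> k) & 1 for ch in text for k in range(7, -1, -1)]

def lines_to_rows(lines, length):
    """Truncate or zero-pad each encoded line to exactly `length` bits."""
    rows = []
    for text in lines:
        bits = to_bits(text)[:length]          # truncate long lines
        bits.extend([0] * (length - len(bits)))  # zero-pad short lines
        rows.append(bits)
    return rows

rows = lines_to_rows(["Hi", "A"], length=12)
print(rows[0])  # first 12 bits of 'Hi'
print(rows[1])  # [0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]  ('A' = 01000001, padded)
```

Every row has identical length, so the result can be wrapped directly in `torch.tensor(rows, dtype=torch.long)` as the scripts do.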
quick_training_run.py DELETED
@@ -1,339 +0,0 @@
- #!/usr/bin/env python3
- """
- Full end-to-end BitTransformerLM training run with all optimizations!
- Small scale test to validate our enhanced system.
- """
-
- import torch
- import torch.nn as nn
- import torch.optim as optim
- from torch.utils.data import Dataset, DataLoader
- import numpy as np
- import logging
- from pathlib import Path
- import time
- from typing import List, Dict, Any
-
- # Import our enhanced modules
- from bit_transformer.model import BitTransformerLM
- from bit_transformer.compression import compress_bits_batch, model_output_decompress
- from bit_transformer.error_handling import safe_model_forward, setup_error_logging
- from bit_transformer.types import BitSequence, TelemetryDict
- from enhanced_checkpoint_system import create_checkpoint_manager
-
- # Setup logging
- logger = setup_error_logging("INFO")
-
- class SimpleBitDataset(Dataset):
-     """Simple dataset of bit sequences for training."""
-
-     def __init__(self, num_samples: int = 1000, seq_length: int = 128):
-         self.num_samples = num_samples
-         self.seq_length = seq_length
-         self.data = self._generate_bit_sequences()
-
-     def _generate_bit_sequences(self) -> List[torch.Tensor]:
-         """Generate diverse bit sequences with different patterns."""
-         sequences = []
-
-         # Pattern 1: Alternating sequences
-         for i in range(self.num_samples // 4):
-             pattern = torch.tensor([i % 2] * self.seq_length, dtype=torch.long)
-             sequences.append(pattern)
-
-         # Pattern 2: Random sequences
-         for i in range(self.num_samples // 4):
-             pattern = torch.randint(0, 2, (self.seq_length,), dtype=torch.long)
-             sequences.append(pattern)
-
-         # Pattern 3: Structured patterns (runs)
-         for i in range(self.num_samples // 4):
-             pattern = []
-             pos = 0
-             while pos < self.seq_length:
-                 run_length = min(np.random.randint(1, 20), self.seq_length - pos)
-                 bit_value = np.random.randint(0, 2)
-                 pattern.extend([bit_value] * run_length)
-                 pos += run_length
-             pattern = torch.tensor(pattern[:self.seq_length], dtype=torch.long)
-             sequences.append(pattern)
-
-         # Pattern 4: Fibonacci-like sequences
-         remaining = self.num_samples - len(sequences)
-         for i in range(remaining):
-             pattern = [0, 1]
-             while len(pattern) < self.seq_length:
-                 pattern.append(pattern[-1] ^ pattern[-2])  # XOR of last two bits
-             pattern = torch.tensor(pattern[:self.seq_length], dtype=torch.long)
-             sequences.append(pattern)
-
-         return sequences
-
-     def __len__(self):
-         return len(self.data)
-
-     def __getitem__(self, idx):
-         sequence = self.data[idx]
-         # For language modeling, input is sequence[:-1], target is sequence[1:]
-         return sequence[:-1], sequence[1:]
-
-
- def compute_safety_metrics(predictions: torch.Tensor, targets: torch.Tensor) -> Dict[str, float]:
-     """Compute K/C/S safety metrics."""
-     pred_bits = (predictions > 0.5).float().flatten()
-
-     # K metric (Negentropy): Measure of order vs randomness
-     if len(pred_bits) > 0:
-         prob_1 = pred_bits.mean().item()
-         prob_0 = 1 - prob_1
-         if prob_0 > 0 and prob_1 > 0:
-             entropy = -prob_0 * np.log2(prob_0) - prob_1 * np.log2(prob_1)
-             negentropy = 1.0 - entropy  # Higher = more ordered
-         else:
-             negentropy = 1.0 if prob_1 == 1.0 or prob_1 == 0.0 else 0.0
-     else:
-         negentropy = 0.0
-
-     # C metric (Complexity): Simple run-length approximation
-     changes = (pred_bits[1:] != pred_bits[:-1]).sum().item()
-     complexity = min(changes / len(pred_bits), 1.0) if len(pred_bits) > 1 else 0.0
-
-     # S metric (Symbiosis): Alignment with target distribution
-     target_bits = targets.float().flatten()
-     if len(target_bits) > 0:
-         target_mean = target_bits.mean()
-         pred_mean = pred_bits.mean()
-         symbiosis = 1.0 - abs(target_mean - pred_mean).item()
-     else:
-         symbiosis = 1.0
-
-     return {
-         'K_negentropy': negentropy,
-         'C_complexity': complexity,
-         'S_symbiosis': symbiosis
-     }
-
-
- def train_bittransformer():
118
- """Main training function with all optimizations."""
119
-
120
- logger.info("🚀 Starting BitTransformerLM end-to-end training run!")
121
-
122
- # Model configuration - small but meaningful
123
- model_config = {
124
- 'd_model': 256,
125
- 'nhead': 8,
126
- 'num_layers': 4,
127
- 'dim_feedforward': 512,
128
- 'max_seq_len': 128,
129
- 'use_checkpoint': True,
130
- 'chunk_size': None, # Disable chunking for small model
131
- }
132
-
133
- training_config = {
134
- 'batch_size': 16,
135
- 'learning_rate': 1e-3,
136
- 'num_epochs': 10,
137
- 'save_every_n_epochs': 2,
138
- 'log_every_n_steps': 10
139
- }
140
-
141
- # Initialize enhanced checkpoint manager
142
- checkpoint_manager = create_checkpoint_manager()
143
- session_id = checkpoint_manager.create_training_session(
144
- session_name="end_to_end_test",
145
- model_config=model_config,
146
- training_config=training_config
147
- )
148
-
149
- logger.info(f"📝 Created training session: {session_id}")
150
-
151
- # Create dataset and dataloader
152
- logger.info("📊 Creating training dataset...")
153
- dataset = SimpleBitDataset(num_samples=800, seq_length=model_config['max_seq_len'])
154
- dataloader = DataLoader(dataset, batch_size=training_config['batch_size'], shuffle=True)
155
-
156
- # Initialize model
157
- logger.info("🧠 Initializing BitTransformerLM model...")
158
- model = BitTransformerLM(
159
- d_model=model_config['d_model'],
160
- nhead=model_config['nhead'],
161
- num_layers=model_config['num_layers'],
162
- dim_feedforward=model_config['dim_feedforward'],
163
- max_seq_len=model_config['max_seq_len'],
164
- use_checkpoint=model_config['use_checkpoint'],
165
- chunk_size=model_config['chunk_size']
166
- )
167
-
168
- # Count parameters
169
- total_params = sum(p.numel() for p in model.parameters())
170
- trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
171
- logger.info(f"🔢 Model parameters: {total_params:,} total, {trainable_params:,} trainable")
172
-
173
- # Setup optimizer and loss
174
- optimizer = optim.AdamW(model.parameters(), lr=training_config['learning_rate'])
175
- scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_config['num_epochs'])
176
- criterion = nn.CrossEntropyLoss()
177
-
178
- # Training loop
179
- logger.info("🏃‍♂️ Starting training loop...")
180
-
181
- for epoch in range(training_config['num_epochs']):
182
- model.train()
183
- epoch_loss = 0.0
184
- epoch_metrics = {'K_negentropy': 0.0, 'C_complexity': 0.0, 'S_symbiosis': 0.0}
185
- num_batches = 0
186
-
187
- start_time = time.time()
188
-
189
- for batch_idx, (inputs, targets) in enumerate(dataloader):
190
- optimizer.zero_grad()
191
-
192
- # Forward pass with safety monitoring
193
- try:
194
- # BitTransformerLM returns (logits, telemetry)
195
- output = safe_model_forward(model, inputs)
196
- if isinstance(output, tuple):
197
- logits, telemetry = output
198
- else:
199
- logits = output
200
- telemetry = {}
201
-
202
- # BitTransformerLM outputs logits for binary classification
203
- # Shape should be [batch, seq_len, 2] for binary vocab
204
- if logits.dim() == 2:
205
- # If [batch*seq_len, 2], already flattened
206
- logits_flat = logits
207
- targets_flat = targets.reshape(-1)
208
- else:
209
- # If [batch, seq_len, 2], flatten
210
- logits_flat = logits.reshape(-1, 2)
211
- targets_flat = targets.reshape(-1)
212
-
213
- loss = criterion(logits_flat, targets_flat)
214
-
215
- # Backward pass
216
- loss.backward()
217
- torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
218
- optimizer.step()
219
-
220
- # Compute metrics
221
- with torch.no_grad():
222
- # Handle different logits shapes for predictions
223
- if logits.dim() == 2:
224
- # [batch*seq_len, 2] -> reshape back to [batch, seq_len, 2]
225
- batch_size = inputs.shape[0]
226
- seq_len = inputs.shape[1]
227
- logits_reshaped = logits.reshape(batch_size, seq_len, 2)
228
- predictions = torch.softmax(logits_reshaped, dim=-1)[:, :, 1] # Prob of bit=1
229
- else:
230
- # [batch, seq_len, 2]
231
- predictions = torch.softmax(logits, dim=-1)[:, :, 1] # Prob of bit=1
232
-
233
- safety_metrics = compute_safety_metrics(predictions, targets)
234
-
235
- epoch_loss += loss.item()
236
- for key, value in safety_metrics.items():
237
- epoch_metrics[key] += value
238
- num_batches += 1
239
-
240
- # Logging
241
- if batch_idx % training_config['log_every_n_steps'] == 0:
242
- logger.info(f"Epoch {epoch+1}/{training_config['num_epochs']}, "
243
- f"Batch {batch_idx}/{len(dataloader)}, "
244
- f"Loss: {loss.item():.4f}, "
245
- f"K: {safety_metrics['K_negentropy']:.3f}, "
246
- f"C: {safety_metrics['C_complexity']:.3f}, "
247
- f"S: {safety_metrics['S_symbiosis']:.3f}")
248
-
249
- except Exception as e:
250
- logger.error(f"Error in batch {batch_idx}: {e}")
251
- continue
252
-
253
- # End of epoch processing
254
- scheduler.step()
255
- epoch_time = time.time() - start_time
256
-
257
- if num_batches > 0:
258
- avg_loss = epoch_loss / num_batches
259
- avg_metrics = {k: v / num_batches for k, v in epoch_metrics.items()}
260
-
261
- logger.info(f"✅ Epoch {epoch+1} completed in {epoch_time:.2f}s")
262
- logger.info(f"📊 Avg Loss: {avg_loss:.4f}")
263
- logger.info(f"📈 Safety Metrics - K: {avg_metrics['K_negentropy']:.3f}, "
264
- f"C: {avg_metrics['C_complexity']:.3f}, "
265
- f"S: {avg_metrics['S_symbiosis']:.3f}")
266
-
267
- # Save checkpoint
268
- if (epoch + 1) % training_config['save_every_n_epochs'] == 0:
269
- checkpoint_success = checkpoint_manager.save_checkpoint(
270
- model=model,
271
- session_id=session_id,
272
- epoch=epoch + 1,
273
- metrics={
274
- 'loss': avg_loss,
275
- 'learning_rate': scheduler.get_last_lr()[0],
276
- **avg_metrics
277
- },
278
- optimizer_state=optimizer.state_dict(),
279
- scheduler_state=scheduler.state_dict()
280
- )
281
-
282
- if checkpoint_success:
283
- logger.info(f"💾 Checkpoint saved for epoch {epoch+1}")
284
-
285
- # Save best model if loss improved
286
- checkpoint_manager.save_best_model(
287
- session_id=session_id,
288
- model=model,
289
- metric_name='loss',
290
- metric_value=avg_loss,
291
- is_better_func=lambda x, y: x < y # Lower loss is better
292
- )
293
-
294
- logger.info("🎉 Training completed successfully!")
295
-
296
- # Test inference and compression
297
- logger.info("🧪 Testing model inference and compression...")
298
-
299
- model.eval()
300
- with torch.no_grad():
301
- # Create a test sequence
302
- test_input = torch.randint(0, 2, (1, 64), dtype=torch.long)
303
- logger.info(f"📥 Input sequence: {test_input.squeeze().tolist()}")
304
-
305
- # Model inference
306
- output_logits = model(test_input)
307
- output_probs = torch.softmax(output_logits, dim=-1)
308
- predicted_bits = torch.argmax(output_probs, dim=-1)
309
-
310
- logger.info(f"📤 Predicted sequence: {predicted_bits.squeeze().tolist()}")
311
-
312
- # Test compression
313
- compressed = compress_bits_batch(predicted_bits)
314
- logger.info(f"🗜️ Compressed length: {len(compressed[0])} (original: {predicted_bits.shape[-1]})")
315
-
316
- # Decompress to verify
317
- decompressed = model_output_decompress(compressed)
318
- compression_success = torch.equal(predicted_bits, decompressed)
319
- logger.info(f"✅ Compression/decompression successful: {compression_success}")
320
-
321
- # Final storage usage report
322
- storage_usage = checkpoint_manager.get_storage_usage()
323
- logger.info(f"💾 Final storage usage: {storage_usage['total_gb']:.3f} GB")
324
- logger.info(f"📁 Training sessions: {storage_usage['num_sessions']}")
325
-
326
- return session_id, model, checkpoint_manager
327
-
328
-
329
- if __name__ == "__main__":
330
- try:
331
- session_id, trained_model, manager = train_bittransformer()
332
- print(f"\n🎉 SUCCESS! Training session completed: {session_id}")
333
- print(f"🔍 Use checkpoint_manager.load_checkpoint('{session_id}') to resume")
334
-
335
- except Exception as e:
336
- logger.error(f"❌ Training failed: {e}")
337
- import traceback
338
- traceback.print_exc()
339
- raise
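The K/C/S telemetry in the deleted script reduces to a few lines of arithmetic. As a minimal sketch (my own dependency-free reimplementation of the same formulas on plain bit lists, not code from the repository), the three metrics can be computed like this:

```python
import math

def safety_metrics(pred_bits, target_bits):
    """K/C/S metrics over lists of 0/1 ints, mirroring compute_safety_metrics."""
    n = len(pred_bits)
    p1 = sum(pred_bits) / n if n else 0.0
    p0 = 1.0 - p1
    # K (negentropy): 1 - Shannon entropy of the bit distribution
    if 0.0 < p1 < 1.0:
        negentropy = 1.0 - (-p0 * math.log2(p0) - p1 * math.log2(p1))
    else:
        negentropy = 1.0 if n else 0.0
    # C (complexity): fraction of adjacent positions where the bit flips
    changes = sum(a != b for a, b in zip(pred_bits[1:], pred_bits[:-1]))
    complexity = min(changes / n, 1.0) if n > 1 else 0.0
    # S (symbiosis): closeness of predicted and target bit frequencies
    t_mean = sum(target_bits) / len(target_bits) if target_bits else 0.0
    symbiosis = 1.0 - abs(t_mean - p1)
    return {"K": negentropy, "C": complexity, "S": symbiosis}
```

A constant all-ones sequence scores maximum order (K = 1) and zero complexity, while a perfectly alternating sequence scores zero negentropy and near-maximal complexity.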
scripts/tools/sync_to_hf.py CHANGED
@@ -6,6 +6,7 @@ Uploads all cleaned documentation and code with proper commit message.
 
 import os
 import logging
+import re
 from pathlib import Path
 from huggingface_hub import HfApi, login
 from typing import Optional, List
@@ -14,6 +15,37 @@ from typing import Optional, List
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
 
+def scan_for_secrets(file_path: Path) -> List[str]:
+    """Scan a file for potential secrets and tokens."""
+    secrets_found = []
+
+    # Patterns for common secrets
+    secret_patterns = {
+        'HuggingFace Token': r'hf_[A-Za-z0-9_]{30,}',
+        'OpenAI API Key': r'sk-[A-Za-z0-9]{48}',
+        'GitHub Token': r'gh[pousr]_[A-Za-z0-9_]{36,}',
+        'AWS Access Key': r'AKIA[0-9A-Z]{16}',
+        'Generic API Key': r'["\']?[Aa]pi[_-]?[Kk]ey["\']?\s*[:=]\s*["\']?[A-Za-z0-9_\-]{20,}["\']?',
+        'Generic Token': r'["\']?[Tt]oken["\']?\s*[:=]\s*["\']?[A-Za-z0-9_\-]{20,}["\']?',
+        'Generic Secret': r'["\']?[Ss]ecret["\']?\s*[:=]\s*["\']?[A-Za-z0-9_\-]{20,}["\']?',
+    }
+
+    try:
+        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+            content = f.read()
+
+        for secret_type, pattern in secret_patterns.items():
+            matches = re.finditer(pattern, content, re.IGNORECASE)
+            for match in matches:
+                line_num = content[:match.start()].count('\n') + 1
+                secrets_found.append(f"{secret_type} found at line {line_num}: {match.group()[:50]}...")
+
+    except Exception as e:
+        logger.warning(f"Could not scan {file_path} for secrets: {e}")
+
+    return secrets_found
+
+
 def get_files_to_sync(repo_root: Path) -> List[Path]:
     """Get the exact list of files that will be synced to HuggingFace."""
     # Files and directories to upload (excluding unnecessary files)
@@ -44,7 +76,7 @@ def get_files_to_sync(repo_root: Path) -> List[Path]:
     ".pytest_cache/**",
     ".ipynb_checkpoints/**",
     "weights/**",
-    "checkpoints/**",
+    "checkpoints/**",  # Contains potentially sensitive configs
     "*.log",
     "*.pt",  # Model weights
     "*.zip",  # Backup files
@@ -130,6 +162,26 @@ def sync_repository_to_hf(
     files_to_upload = get_files_to_sync(repo_root)
     logger.info(f"Found {len(files_to_upload)} files to upload")
 
+    # CRITICAL SECURITY CHECK: scan all files for secrets
+    logger.info("🔍 Scanning files for secrets and tokens...")
+    all_secrets = []
+    for file_path in files_to_upload:
+        secrets = scan_for_secrets(file_path)
+        if secrets:
+            relative_path = file_path.relative_to(repo_root)
+            all_secrets.extend([f"{relative_path}: {secret}" for secret in secrets])
+
+    if all_secrets:
+        logger.error("🚨 SECURITY ALERT: Secrets detected in files!")
+        logger.error("The following secrets were found and MUST be removed before sync:")
+        for secret in all_secrets:
+            logger.error(f"  - {secret}")
+        logger.error("❌ SYNC ABORTED for security reasons!")
+        logger.error("Please remove all secrets and use environment variables instead.")
+        return False
+
+    logger.info("✅ Security scan passed - no secrets detected")
+
     # If preview only, just show the files and return
     if preview_only:
         preview_sync(repo_root)
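The regex-based secret scan added above can be exercised in isolation. The following standalone sketch uses two of the patterns from the diff against fabricated placeholder strings (the "tokens" below are dummy filler, not real credentials, and `find_secrets` is a hypothetical helper, not the repository's `scan_for_secrets`):

```python
import re

# Two of the patterns from scan_for_secrets in sync_to_hf.py
SECRET_PATTERNS = {
    "HuggingFace Token": r"hf_[A-Za-z0-9_]{30,}",
    "AWS Access Key": r"AKIA[0-9A-Z]{16}",
}

def find_secrets(text):
    """Return (pattern_name, line_number) for every match in the text."""
    hits = []
    for name, pat in SECRET_PATTERNS.items():
        for m in re.finditer(pat, text):
            line = text[: m.start()].count("\n") + 1
            hits.append((name, line))
    return hits
```

Running it on a two-line snippet containing a fake HF token on line 1 and a fake AWS key on line 2 reports both, with the same line-number arithmetic (`count('\n') + 1`) the sync script uses.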
scripts/training/breakthrough_training.py ADDED
@@ -0,0 +1,199 @@
+#!/usr/bin/env python3
+"""
+BREAKTHROUGH BitTransformerLM Training Script
+===========================================
+
+Using the ACTUAL BitTransformerLM model and training infrastructure,
+configured for the Fixed RL Adafactor breakthrough results.
+"""
+
+import sys
+import os
+import logging
+from pathlib import Path
+
+import torch
+from datasets import load_dataset
+from huggingface_hub import login
+
+# Add paths for imports
+sys.path.append('/data')
+sys.path.append('/data/BitTransformerLM')
+
+from bit_transformer import (
+    BitTransformerLM,
+    text_to_bits,
+    train_loop,
+    save_model,
+    load_model,
+    set_dropout
+)
+from BTLM_Extensions import configure_adafactor_optimizer
+
+# Setup logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(levelname)s - %(message)s',
+    handlers=[
+        logging.FileHandler('breakthrough_training.log'),
+        logging.StreamHandler()
+    ]
+)
+logger = logging.getLogger(__name__)
+
+def load_and_prepare_dataset():
+    """Load HF dataset and convert to bit tensors."""
+    logger.info("Loading WCNegentropy/BitTransformerLM dataset...")
+
+    # Login to HuggingFace
+    hf_token = os.getenv('HF_TOKEN')
+    if hf_token:
+        login(token=hf_token)
+    else:
+        print("Warning: HF_TOKEN environment variable not set")
+
+    # Load dataset
+    dataset = load_dataset("WCNegentropy/BitTransformerLM")
+    train_data = dataset['train']
+
+    logger.info(f"Dataset loaded: {len(train_data)} samples")
+
+    # Process dataset - the HF dataset already has a bit_sequence field!
+    bit_sequences = []
+    for sample in train_data:
+        if 'bit_sequence' in sample and sample['bit_sequence'] is not None:
+            # The bit_sequence might already be a list
+            bits = sample['bit_sequence']
+            if isinstance(bits, str):
+                try:
+                    bits = eval(bits)  # Convert string representation to list
+                except:
+                    bits = None
+            if isinstance(bits, list) and len(bits) > 0:
+                bit_sequences.append(bits)
+            else:
+                # Fallback: convert original_text to bits
+                text = sample.get('original_text', '')
+                if text:
+                    bits = text_to_bits(text)
+                    bit_sequences.append(bits)
+        else:
+            # Fallback: convert text to bits
+            text = sample.get('text', '') or sample.get('original_text', '')
+            if text:
+                bits = text_to_bits(text)
+                bit_sequences.append(bits)
+
+    logger.info(f"Processed {len(bit_sequences)} bit sequences")
+
+    # Create training tensors with proper sequence length
+    max_len = 512  # BitTransformerLM default max_seq_len
+    training_sequences = []
+
+    for bits in bit_sequences:
+        # Split long sequences into overlapping chunks
+        for i in range(0, len(bits) - max_len + 1, max_len // 2):
+            seq = bits[i:i + max_len]
+            if len(seq) == max_len:  # Only use full-length sequences
+                training_sequences.append(seq)
+
+    # Convert to tensor
+    data_tensor = torch.tensor(training_sequences, dtype=torch.long)
+    logger.info(f"Created training tensor: {data_tensor.shape}")
+
+    return data_tensor
+
+def create_breakthrough_model():
+    """Create the EXACT breakthrough BitTransformerLM configuration."""
+    logger.info("Creating breakthrough BitTransformerLM model...")
+
+    # EXACT breakthrough configuration using ACTUAL BitTransformerLM parameters
+    model = BitTransformerLM(
+        d_model=512,            # Breakthrough config
+        nhead=16,               # 16 attention heads
+        num_layers=8,           # 8 layers for ~16M params
+        dim_feedforward=1024,   # 2x d_model
+        max_seq_len=512,        # Match data preparation
+        reversible=True,        # Memory efficiency
+        use_checkpoint=True,    # Gradient checkpointing
+        use_autocast=True,      # Mixed precision
+        use_act=True,           # Adaptive Computation Time
+        act_threshold=0.9,
+        lambda_K=0.05,          # Safety telemetry weights
+        lambda_C=0.05,
+        lambda_S=0.05
+    )
+
+    # Calculate parameter count
+    total_params = sum(p.numel() for p in model.parameters())
+    logger.info(f"Model created: {total_params:,} parameters")
+    logger.info(f"Target: ~16M parameters - {'✓' if 15_000_000 <= total_params <= 17_000_000 else '✗'}")
+
+    return model
+
+def main():
+    """Main training function."""
+    logger.info("🚀 STARTING BREAKTHROUGH BITRANSFORMERLM TRAINING!")
+    logger.info("Using ACTUAL BitTransformerLM model and train_loop")
+
+    # Load dataset
+    data = load_and_prepare_dataset()
+
+    # Create model
+    model = create_breakthrough_model()
+
+    # CRITICAL: Use Fixed RL Adafactor (the breakthrough secret!)
+    logger.info("Configuring Fixed RL Adafactor optimizer...")
+    optimizer, scheduler = configure_adafactor_optimizer(
+        model,
+        lr=1e-3,            # FIXED learning rate - key to breakthrough!
+        weight_decay=0.01,
+        total_steps=5000    # Estimated total steps
+    )
+    logger.info("Fixed RL Adafactor configured with LR=0.001")
+
+    # Training configuration
+    training_config = {
+        'epochs': 20,           # Reasonable number of epochs
+        'batch_size': 4,        # Adjust based on memory
+        'accum_steps': 4,       # Gradient accumulation
+        'amp': True,            # Mixed precision
+        'log': True,            # Enable logging
+        'compress_prob': 0.0,   # Start with no compression
+        'optimizer': optimizer,
+        'scheduler': scheduler
+    }
+
+    logger.info(f"Training configuration: {training_config}")
+    logger.info("Starting training loop...")
+
+    # Use the ACTUAL BitTransformerLM train_loop function!
+    metrics = train_loop(
+        model=model,
+        data=data,
+        **training_config
+    )
+
+    # Save the trained model
+    checkpoint_dir = Path('/data/BitTransformerLM/checkpoints')
+    checkpoint_dir.mkdir(exist_ok=True)
+
+    model_path = checkpoint_dir / 'breakthrough_model.pt'
+    save_model(model, model_path)
+    logger.info(f"Model saved to: {model_path}")
+
+    # Log final metrics
+    if metrics:
+        final_metrics = metrics[-1]
+        logger.info("🎉 TRAINING COMPLETED!")
+        logger.info(f"Final raw_loss: {final_metrics['raw_loss']:.6f}")
+        logger.info(f"Final raw_acc: {final_metrics['raw_acc']:.3f}")
+
+        # Check for breakthrough performance
+        if final_metrics['raw_loss'] < 3.0:
+            logger.info("🚀 BREAKTHROUGH PERFORMANCE ACHIEVED! Loss < 3.0!")
+
+    logger.info("Breakthrough training completed successfully!")
+
+if __name__ == "__main__":
+    main()
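The 50%-overlap chunking used in `load_and_prepare_dataset` above is worth seeing on its own. `chunk_bits` below is a hypothetical standalone helper mirroring that loop, not a function from the repository:

```python
def chunk_bits(bits, max_len):
    """Split a bit list into overlapping full-length windows (stride = max_len // 2)."""
    return [
        bits[i:i + max_len]
        for i in range(0, len(bits) - max_len + 1, max_len // 2)
    ]
```

Because the range stops at `len(bits) - max_len + 1`, every emitted window is exactly `max_len` bits (no padding logic is needed), and the half-length stride roughly doubles the number of training windows per source sequence; sequences shorter than `max_len` yield nothing.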
scripts/training/final_breakthrough_training.py ADDED
@@ -0,0 +1,426 @@
+#!/usr/bin/env python3
+"""
+Final Breakthrough BitTransformerLM Training Script
+=================================================
+
+The complete training script using the ACTUAL BitTransformerLM model
+with the breakthrough Fixed RL Adafactor configuration and full
+HuggingFace dataset support with checkpoint resumption.
+"""
+
+import sys
+import os
+import json
+import logging
+from pathlib import Path
+from datetime import datetime
+from typing import Optional, Dict, Any
+
+import torch
+import torch.nn.functional as F
+from datasets import load_dataset
+from huggingface_hub import login
+
+# Add paths for imports
+sys.path.append('/data')
+sys.path.append('/data/BitTransformerLM')
+
+from bit_transformer import BitTransformerLM, text_to_bits
+from BTLM_Extensions import configure_adafactor_optimizer
+
+# Setup logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(levelname)s - %(message)s',
+    handlers=[
+        logging.FileHandler('/data/BitTransformerLM/breakthrough_training.log'),
+        logging.StreamHandler()
+    ]
+)
+logger = logging.getLogger(__name__)
+
+class BreakthroughTrainer:
+    """Production-grade BitTransformerLM trainer with breakthrough configuration."""
+
+    def __init__(self, config: Dict[str, Any]):
+        self.config = config
+        self.device = torch.device('cpu')  # CPU training as per breakthrough
+        self.model = None
+        self.optimizer = None
+        self.scheduler = None
+        self.dataset = None
+        self.checkpoint_dir = Path(config['checkpoint_dir'])
+        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+        # Training state
+        self.current_epoch = 0
+        self.total_steps = 0
+        self.best_loss = float('inf')
+        self.training_history = []
+
+    def load_and_prepare_dataset(self):
+        """Load HF dataset and convert to proper bit tensors."""
+        logger.info("Loading WCNegentropy/BitTransformerLM dataset...")
+
+        # Login to HuggingFace
+        login(token=self.config['hf_token'])
+
+        # Load dataset
+        dataset = load_dataset("WCNegentropy/BitTransformerLM")
+        train_data = dataset['train']
+
+        logger.info(f"Dataset loaded: {len(train_data)} samples")
+
+        # Process dataset - convert to bits using the ACTUAL text_to_bits function
+        bit_sequences = []
+        for i, sample in enumerate(train_data):
+            if i % 1000 == 0:
+                logger.info(f"Processing sample {i}/{len(train_data)}")
+
+            # Try to get text from various fields
+            text = None
+            if 'original_text' in sample and sample['original_text']:
+                text = sample['original_text']
+            elif 'text' in sample and sample['text']:
+                text = sample['text']
+
+            if text and text.strip():
+                # Use ACTUAL text_to_bits function
+                bits = text_to_bits(text)
+                if len(bits) >= self.config['sequence_length']:
+                    bit_sequences.append(bits)
+
+        logger.info(f"Processed {len(bit_sequences)} valid bit sequences")
+
+        # Create training sequences with proper length
+        seq_len = self.config['sequence_length']
+        training_sequences = []
+
+        for bits in bit_sequences:
+            # Create overlapping chunks
+            for i in range(0, len(bits) - seq_len + 1, seq_len // 2):
+                chunk = bits[i:i + seq_len]
+                if len(chunk) == seq_len:
+                    training_sequences.append(chunk)
+
+        # Convert to tensor with proper dtype
+        self.dataset = torch.tensor(training_sequences, dtype=torch.long)
+        logger.info(f"Created training dataset: {self.dataset.shape}")
+
+        return self.dataset
+
+    def create_breakthrough_model(self):
+        """Create the EXACT breakthrough 16M parameter BitTransformerLM."""
+        logger.info("Creating breakthrough 16M parameter BitTransformerLM...")
+
+        # BREAKTHROUGH CONFIGURATION - exactly as identified before
+        self.model = BitTransformerLM(
+            d_model=512,            # Breakthrough config
+            nhead=16,               # 16 attention heads
+            num_layers=8,           # 8 layers for ~16M params
+            dim_feedforward=1024,   # 2x d_model
+            max_seq_len=self.config['sequence_length'],
+            lambda_K=0.05,          # Safety telemetry weights
+            lambda_C=0.05,
+            lambda_S=0.05,
+            reversible=True,        # Memory efficiency
+            use_checkpoint=True,    # Gradient checkpointing
+            use_autocast=True,      # CPU mixed precision
+            use_act=True,           # Adaptive Computation Time
+            act_threshold=0.9
+        ).to(self.device)
+
+        # Calculate and verify parameter count
+        total_params = sum(p.numel() for p in self.model.parameters())
+        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+
+        logger.info(f"Model created: {total_params:,} total parameters ({trainable_params:,} trainable)")
+        logger.info(f"Target: ~16M parameters - {'✓' if 15_000_000 <= total_params <= 17_000_000 else '✗'}")
+
+        return self.model
+
+    def setup_optimizer(self):
+        """Setup Fixed RL Adafactor optimizer (the breakthrough secret sauce)."""
+        logger.info("Setting up Fixed RL Adafactor optimizer...")
+
+        # Calculate total steps
+        steps_per_epoch = len(self.dataset) // self.config['batch_size']
+        total_steps = steps_per_epoch * self.config['num_epochs']
+
+        # CRITICAL: Use FIXED LR, not auto-LR (the breakthrough discovery!)
+        self.optimizer, self.scheduler = configure_adafactor_optimizer(
+            self.model,
+            lr=self.config['learning_rate'],  # FIXED LR - key to breakthrough!
+            weight_decay=self.config['weight_decay'],
+            total_steps=total_steps
+        )
+
+        logger.info(f"Fixed RL Adafactor configured with LR={self.config['learning_rate']}")
+        logger.info(f"Total training steps: {total_steps}")
+
+        return self.optimizer, self.scheduler
+
+    def save_checkpoint(self, epoch: int, loss: float, is_best: bool = False):
+        """Save complete model checkpoint with all training state."""
+        checkpoint_data = {
+            'epoch': epoch,
+            'total_steps': self.total_steps,
+            'model_state_dict': self.model.state_dict(),
+            'optimizer_state_dict': self.optimizer.state_dict(),
+            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
+            'loss': loss,
+            'best_loss': self.best_loss,
+            'config': self.config,
+            'training_history': self.training_history,
+            'timestamp': datetime.now().isoformat(),
+            'model_config': self.model._current_params()  # Save model hyperparams
+        }
+
+        # Save latest checkpoint
+        latest_path = self.checkpoint_dir / 'checkpoint_latest.pt'
+        torch.save(checkpoint_data, latest_path)
+        logger.info(f"Saved checkpoint: {latest_path}")
+
+        # Save epoch-specific checkpoint
+        epoch_path = self.checkpoint_dir / f'checkpoint_epoch_{epoch:04d}.pt'
+        torch.save(checkpoint_data, epoch_path)
+
+        # Save best model if this is the best loss
+        if is_best:
+            best_path = self.checkpoint_dir / 'checkpoint_best.pt'
+            torch.save(checkpoint_data, best_path)
+            logger.info(f"🏆 NEW BEST MODEL! Loss: {loss:.6f} -> {best_path}")
+
+        # Save training config for reference
+        config_path = self.checkpoint_dir / 'training_config.json'
+        with open(config_path, 'w') as f:
+            json.dump(self.config, f, indent=2)
+
+    def load_checkpoint(self, checkpoint_path: Optional[str] = None) -> bool:
+        """Load checkpoint if available and resume training."""
+        if checkpoint_path is None:
+            checkpoint_path = self.checkpoint_dir / 'checkpoint_latest.pt'
+
+        checkpoint_path = Path(checkpoint_path)
+        if not checkpoint_path.exists():
+            logger.info("No checkpoint found - starting fresh training")
+            return False
+
+        logger.info(f"Loading checkpoint: {checkpoint_path}")
+        try:
+            checkpoint = torch.load(checkpoint_path, map_location=self.device)
+
+            # Load model state
+            self.model.load_state_dict(checkpoint['model_state_dict'])
+
+            # Load optimizer state
+            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+
+            # Load scheduler state
+            if self.scheduler and checkpoint.get('scheduler_state_dict'):
+                self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+
+            # Load training state
+            self.current_epoch = checkpoint['epoch']
+            self.total_steps = checkpoint['total_steps']
+            self.best_loss = checkpoint['best_loss']
+            self.training_history = checkpoint.get('training_history', [])
+
+            logger.info(f"✅ Resumed from epoch {self.current_epoch}, best loss: {self.best_loss:.6f}")
+            logger.info(f"Total steps completed: {self.total_steps}")
+            return True
+
+        except Exception as e:
+            logger.error(f"Failed to load checkpoint: {e}")
+            return False
+
+    def training_step(self, batch: torch.Tensor) -> Dict[str, float]:
+        """Single training step following the ACTUAL model pattern."""
+        batch = batch.to(self.device)
+
+        # Zero gradients
+        self.optimizer.zero_grad()
+
+        # Forward pass - EXACTLY like the working basic_training.py
+        logits, telemetry = self.model(batch)
+
+        # Loss calculation - EXACTLY like example_training_step
+        pred = logits[:, :-1, :].reshape(-1, 2)
+        target = batch[:, 1:].reshape(-1)
+        loss = F.cross_entropy(pred, target)
+
+        # Backward pass
+        loss.backward()
+
+        # Gradient clipping
+        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['max_grad_norm'])
+
+        # Optimizer step
+        self.optimizer.step()
+        if self.scheduler:
+            self.scheduler.step()
+
+        self.total_steps += 1
+
+        # Extract telemetry values properly
+        metrics = {'loss': loss.item()}
+        if telemetry:
+            for key, value in telemetry.items():
+                if torch.is_tensor(value):
+                    metrics[key] = value.mean().item()
+                else:
+                    metrics[key] = value
+
+        return metrics
+
+    def train_epoch(self) -> Dict[str, float]:
+        """Train for one complete epoch."""
+        logger.info(f"Starting epoch {self.current_epoch + 1}")
+
+        # Use EXACT same pattern as working basic_training.py
+        self.model.train()
+        epoch_losses = []
+
+        # Simple batching - EXACTLY like working basic_training.py
+        batch_size = self.config['batch_size']
+        for i in range(0, len(self.dataset), batch_size):
+            batch = self.dataset[i:i + batch_size]
+            if len(batch) < batch_size:
+                continue  # Skip incomplete batches
+
+            batch = batch.to(self.device)
+
+            # Zero gradients
+            self.optimizer.zero_grad()
+
+            # Forward pass - EXACTLY like working basic_training.py
+            logits, telemetry = self.model(batch)
+
+            # Loss calculation - EXACTLY like working basic_training.py
+            pred = logits[:, :-1, :].reshape(-1, 2)
+            target = batch[:, 1:].reshape(-1)
302
+ loss = F.cross_entropy(pred, target)
303
+
304
+ # Backward pass
305
+ loss.backward()
306
+
307
+ # Gradient clipping
308
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['max_grad_norm'])
309
+
310
+ # Optimizer step
311
+ self.optimizer.step()
312
+ if self.scheduler:
313
+ self.scheduler.step()
314
+
315
+ self.total_steps += 1
316
+ epoch_losses.append(loss.item())
317
+
318
+ # Calculate epoch averages - simplified like basic_training.py
319
+ avg_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float('inf')
320
+
321
+ epoch_summary = {
322
+ 'epoch': self.current_epoch + 1,
323
+ 'avg_loss': avg_loss
324
+ }
325
+
326
+ self.training_history.append(epoch_summary)
327
+
328
+ logger.info(
329
+ f"Epoch {self.current_epoch + 1} completed: "
330
+ f"Avg Loss={avg_loss:.6f}"
331
+ )
332
+
333
+ return epoch_summary
334
+
335
+ def train(self):
336
+ """Main training loop."""
337
+ logger.info("🚀 STARTING BREAKTHROUGH BITTRANSFORMERLM TRAINING!")
338
+ logger.info("Configuration: Fixed LR Adafactor + 16M parameters + CPU training")
339
+
340
+ start_epoch = self.current_epoch
341
+
342
+ for epoch in range(start_epoch, self.config['num_epochs']):
343
+ try:
344
+ # Train epoch
345
+ epoch_metrics = self.train_epoch()
346
+ avg_loss = epoch_metrics['avg_loss']
347
+
348
+ # Check if this is the best model
349
+ is_best = avg_loss < self.best_loss
350
+ if is_best:
351
+ self.best_loss = avg_loss
352
+
353
+ # Save checkpoint after each epoch
354
+ self.save_checkpoint(self.current_epoch + 1, avg_loss, is_best)
355
+
356
+ self.current_epoch += 1
357
+
358
+ # Log progress
359
+ logger.info(f"=== EPOCH {self.current_epoch} COMPLETE ===")
360
+ logger.info(f"Loss: {avg_loss:.6f} (best: {self.best_loss:.6f})")
361
+
362
+ # Check for breakthrough performance (loss < 3.0)
363
+ if avg_loss < 3.0:
364
+ logger.info("🚀 BREAKTHROUGH PERFORMANCE ACHIEVED! Loss < 3.0!")
365
+
366
+ except KeyboardInterrupt:
367
+ logger.info("Training interrupted by user")
368
+ # Save checkpoint before exiting
369
+ try:
370
+ self.save_checkpoint(self.current_epoch, float('inf'), False)
371
+ except Exception:
372
+ pass
373
+ break
374
+ except Exception as e:
375
+ logger.error(f"Error in epoch {self.current_epoch + 1}: {e}")
376
+ # Save emergency checkpoint
377
+ try:
378
+ self.save_checkpoint(self.current_epoch, float('inf'), False)
379
+ except Exception:
380
+ pass
381
+ raise
382
+
383
+ def main():
384
+ """Main function to run breakthrough training."""
385
+
386
+ # BREAKTHROUGH TRAINING CONFIGURATION
387
+ config = {
388
+ # Model parameters (breakthrough configuration)
389
+ 'sequence_length': 512,
390
+
391
+ # Training parameters
392
+ 'learning_rate': 1e-3, # FIXED LR - key to breakthrough!
393
+ 'weight_decay': 0.01,
394
+ 'batch_size': 4, # Adjust based on memory
395
+ 'num_epochs': 50, # Full training run
396
+ 'max_grad_norm': 1.0,
397
+
398
+ # Data parameters
399
+ 'hf_token': None, # Set via environment variable HF_TOKEN
400
+
401
+ # Logging and checkpointing
402
+ 'log_interval': 100,
403
+ 'checkpoint_dir': '/data/BitTransformerLM/checkpoints',
404
+ }
405
+
406
+ # Create trainer
407
+ trainer = BreakthroughTrainer(config)
408
+
409
+ # Setup all components
410
+ logger.info("Setting up training components...")
411
+ trainer.load_and_prepare_dataset()
412
+ trainer.create_breakthrough_model()
413
+ trainer.setup_optimizer()
414
+
415
+ # Try to resume from checkpoint
416
+ trainer.load_checkpoint()
417
+
418
+ # Start training
419
+ trainer.train()
420
+
421
+ logger.info("🎉 BREAKTHROUGH TRAINING COMPLETED!")
422
+ logger.info(f"Best loss achieved: {trainer.best_loss:.6f}")
423
+ logger.info(f"Checkpoints saved to: {trainer.checkpoint_dir}")
424
+
425
+ if __name__ == "__main__":
426
+ main()
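The training step in the script above learns next-bit prediction by pairing `logits[:, :-1, :]` with `batch[:, 1:]`. A minimal dependency-free sketch of that index shift (illustrative only; `next_bit_pairs` is a hypothetical helper, not part of BitTransformerLM):

```python
def next_bit_pairs(bits):
    """Mirror the shift in training_step: the output at position i is
    scored against bit i+1 (logits[:, :-1] vs. batch[:, 1:])."""
    contexts = bits[:-1]  # predictions come from every position but the last
    targets = bits[1:]    # each target is the following bit
    return list(zip(contexts, targets))

pairs = next_bit_pairs([1, 0, 0, 1])
# → [(1, 0), (0, 0), (0, 1)]
```

The same alignment is why `pred` has one fewer time step than `batch`: the final position has no following bit to predict.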
scripts/training/full_attention_training.py ADDED
@@ -0,0 +1,467 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ BitTransformerLM Full Bi-Directional Attention Training Script
4
+ ===============================================================
5
+
6
+ This script implements the breakthrough Fixed RL Adafactor training configuration
7
+ for production-scale BitTransformerLM training with FULL BI-DIRECTIONAL UNCHUNKED ATTENTION.
8
+
9
+ Configuration:
10
+ - Model: 16M parameters (d_model=512, nhead=16, num_layers=8)
11
+ - Attention: FULL BI-DIRECTIONAL UNCHUNKED (chunk_size=None)
12
+ - Optimizer: Fixed LR Adafactor (identical to breakthrough config)
13
+ - Features: Reversible layers, ACT, QAT, compression
14
+ - Data: HuggingFace WCNegentropy/BitTransformerLM dataset
15
+ - Checkpointing: After every training cycle for continuous training
16
+ """
17
+
18
+ import sys
19
+ import os
20
+ import json
21
+ import time
22
+ import logging
23
+ from datetime import datetime
24
+ from pathlib import Path
25
+ from typing import Optional, Dict, Any
26
+
27
+ import torch
28
+ import torch.nn.functional as F
29
+ from datasets import load_dataset
30
+ from huggingface_hub import login
31
+
32
+ # Add paths for imports
33
+ sys.path.append('/data')
34
+ sys.path.append('/data/BitTransformerLM')
35
+
36
+ from bit_transformer import (
37
+ BitTransformerLM,
38
+ text_to_bits,
39
+ bits_to_text,
40
+ save_model,
41
+ load_model,
42
+ set_dropout
43
+ )
44
+ from BTLM_Extensions import configure_adafactor_optimizer
45
+
46
+ # Setup logging
47
+ logging.basicConfig(
48
+ level=logging.INFO,
49
+ format='%(asctime)s - %(levelname)s - %(message)s',
50
+ handlers=[
51
+ logging.FileHandler('full_attention_training.log'),
52
+ logging.StreamHandler()
53
+ ]
54
+ )
55
+ logger = logging.getLogger(__name__)
56
+
57
+ class ProductionTrainer:
58
+ """Production-grade BitTransformerLM trainer with breakthrough configuration."""
59
+
60
+ def __init__(self, config: Dict[str, Any]):
61
+ self.config = config
62
+ self.device = torch.device('cpu') # CPU training as per breakthrough
63
+ self.model = None
64
+ self.optimizer = None
65
+ self.scheduler = None
66
+ self.dataset = None
67
+ self.checkpoint_dir = Path(config['checkpoint_dir'])
68
+ self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
69
+
70
+ # Training state
71
+ self.current_epoch = 0
72
+ self.total_steps = 0
73
+ self.best_loss = float('inf')
74
+ self.training_history = []
75
+
76
+ def setup_model(self):
77
+ """Create the breakthrough 16M parameter BitTransformerLM model with full bi-directional attention."""
78
+ logger.info("Setting up breakthrough BitTransformerLM with FULL BI-DIRECTIONAL UNCHUNKED ATTENTION...")
79
+
80
+ self.model = BitTransformerLM(
81
+ d_model=512, # Breakthrough config
82
+ nhead=16, # 16 attention heads
83
+ num_layers=8, # 8 layers for ~16M params
84
+ dim_feedforward=1024, # 2x d_model for optimal params
85
+ max_seq_len=512, # Reasonable sequence length
86
+ reversible=True, # Memory efficiency
87
+ use_checkpoint=True, # Gradient checkpointing
88
+ use_autocast=True, # CPU mixed precision
89
+ use_act=True, # Adaptive Computation Time
90
+ act_threshold=0.9, # ACT threshold
91
+ lambda_K=0.05, # Safety telemetry weights
92
+ lambda_C=0.05,
93
+ lambda_S=0.05,
94
+ chunk_size=None, # FULL ATTENTION - no chunking
95
+ overlap=0, # No overlap needed for full attention
96
+ full_attn_logging=True # Enable full attention logging
97
+ ).to(self.device)
98
+
99
+ # Calculate actual parameter count
100
+ total_params = sum(p.numel() for p in self.model.parameters())
101
+ trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
102
+
103
+ logger.info(f"Model created: {total_params:,} total parameters ({trainable_params:,} trainable)")
104
+ logger.info(f"Target: ~16M parameters - {'✓' if 15_000_000 <= total_params <= 17_000_000 else '✗'}")
105
+
106
+ return self.model
107
+
108
+ def setup_optimizer(self):
109
+ """Setup Fixed LR Adafactor optimizer (the breakthrough secret sauce)."""
110
+ logger.info("Setting up Fixed RL Adafactor optimizer...")
111
+
112
+ # CRITICAL: Use fixed LR, not auto-LR (lr=None)
113
+ self.optimizer, self.scheduler = configure_adafactor_optimizer(
114
+ self.model,
115
+ lr=self.config['learning_rate'], # Fixed LR - the key to breakthrough!
116
+ weight_decay=self.config['weight_decay'],
117
+ total_steps=self.config['total_steps']
118
+ )
119
+
120
+ logger.info(f"Fixed LR Adafactor configured with LR={self.config['learning_rate']}")
121
+ return self.optimizer, self.scheduler
122
+
123
+ def setup_dataset(self):
124
+ """Load and prepare the WCNegentropy/BitTransformerLM dataset."""
125
+ logger.info("Loading WCNegentropy/BitTransformerLM dataset...")
126
+
127
+ # Login to HuggingFace
128
+ login(token=self.config['hf_token'] or os.environ.get('HF_TOKEN'))
129
+
130
+ # Load dataset
131
+ try:
132
+ dataset = load_dataset("WCNegentropy/BitTransformerLM")
133
+ logger.info(f"Dataset loaded: {dataset}")
134
+
135
+ # Use train split
136
+ train_data = dataset['train'] if 'train' in dataset else dataset
137
+ logger.info(f"Training samples: {len(train_data)}")
138
+
139
+ # Process dataset - convert to bits using the ACTUAL text_to_bits function
140
+ bit_sequences = []
141
+ for i, sample in enumerate(train_data):
142
+ if i % 1000 == 0:
143
+ logger.info(f"Processing sample {i}/{len(train_data)}")
144
+
145
+ # Try to get text from various fields
146
+ text = None
147
+ if 'original_text' in sample and sample['original_text']:
148
+ text = sample['original_text']
149
+ elif 'text' in sample and sample['text']:
150
+ text = sample['text']
151
+
152
+ if text and text.strip():
153
+ # Use ACTUAL text_to_bits function
154
+ bits = text_to_bits(text)
155
+ if len(bits) >= self.config['sequence_length']:
156
+ bit_sequences.append(bits)
157
+
158
+ logger.info(f"Processed {len(bit_sequences)} valid bit sequences")
159
+
160
+ # Create training sequences with proper length
161
+ seq_len = self.config['sequence_length']
162
+ training_sequences = []
163
+
164
+ for bits in bit_sequences:
165
+ # Create overlapping chunks
166
+ for i in range(0, len(bits) - seq_len + 1, seq_len // 2):
167
+ chunk = bits[i:i + seq_len]
168
+ if len(chunk) == seq_len:
169
+ training_sequences.append(chunk)
170
+
171
+ # Convert to tensor with proper dtype
172
+ self.dataset = torch.tensor(training_sequences, dtype=torch.long)
173
+ logger.info(f"Created training dataset: {self.dataset.shape}")
174
+
175
+ except Exception as e:
176
+ logger.error(f"Failed to load dataset: {e}")
177
+ # Fallback to synthetic data for testing
178
+ logger.info("Falling back to synthetic bit data...")
179
+ synthetic_bits = torch.randint(0, 2, (1000, self.config['sequence_length']))
180
+ self.dataset = synthetic_bits
181
+ logger.warning("Using synthetic data - replace with real dataset for production")
182
+
183
+ return self.dataset
184
+
185
+ def save_checkpoint(self, epoch: int, loss: float, is_best: bool = False):
186
+ """Save model checkpoint with all training state."""
187
+ checkpoint_data = {
188
+ 'epoch': epoch,
189
+ 'total_steps': self.total_steps,
190
+ 'model_state_dict': self.model.state_dict(),
191
+ 'optimizer_state_dict': self.optimizer.state_dict(),
192
+ 'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
193
+ 'loss': loss,
194
+ 'best_loss': self.best_loss,
195
+ 'config': self.config,
196
+ 'training_history': self.training_history,
197
+ 'timestamp': datetime.now().isoformat()
198
+ }
199
+
200
+ # Save latest checkpoint
201
+ latest_path = self.checkpoint_dir / 'checkpoint_latest.pt'
202
+ torch.save(checkpoint_data, latest_path)
203
+ logger.info(f"Saved checkpoint: {latest_path}")
204
+
205
+ # Save epoch-specific checkpoint
206
+ epoch_path = self.checkpoint_dir / f'checkpoint_epoch_{epoch:04d}.pt'
207
+ torch.save(checkpoint_data, epoch_path)
208
+
209
+ # Save best model if this is the best loss
210
+ if is_best:
211
+ best_path = self.checkpoint_dir / 'checkpoint_best.pt'
212
+ torch.save(checkpoint_data, best_path)
213
+ logger.info(f"NEW BEST MODEL! Loss: {loss:.6f} -> {best_path}")
214
+
215
+ # Save training config for reference
216
+ config_path = self.checkpoint_dir / 'training_config.json'
217
+ with open(config_path, 'w') as f:
218
+ json.dump(self.config, f, indent=2)
219
+
220
+ def load_checkpoint(self, checkpoint_path: Optional[str] = None) -> bool:
221
+ """Load model weights from latest checkpoint but restart training from epoch 1."""
222
+ if checkpoint_path is None:
223
+ checkpoint_path = self.checkpoint_dir / 'checkpoint_latest.pt'
224
+
225
+ checkpoint_path = Path(checkpoint_path)
226
+ if not checkpoint_path.exists():
227
+ logger.info("No checkpoint found - starting fresh training")
228
+ return False
229
+
230
+ logger.info(f"Loading model weights from: {checkpoint_path}")
231
+ try:
232
+ checkpoint = torch.load(checkpoint_path, map_location=self.device)
233
+
234
+ # Load ONLY model weights
235
+ self.model.load_state_dict(checkpoint['model_state_dict'])
236
+
237
+ # RESET all training state to start from epoch 1
238
+ self.current_epoch = 0
239
+ self.total_steps = 0
240
+ self.best_loss = float('inf')
241
+ self.training_history = []
242
+
243
+ # DO NOT load optimizer/scheduler state - fresh start
244
+
245
+ logger.info("Loaded model weights, restarting training from epoch 1, step 0")
246
+ return True
247
+
248
+ except Exception as e:
249
+ logger.error(f"Failed to load checkpoint: {e}")
250
+ return False
251
+
252
+ def training_step(self, batch: torch.Tensor) -> Dict[str, float]:
253
+ """Single training step with telemetry."""
254
+ self.model.train()
255
+ set_dropout(self.model, self.config['dropout'])
256
+
257
+ batch = batch.to(self.device)
258
+
259
+ # Zero gradients
260
+ self.optimizer.zero_grad()
261
+
262
+ # Forward pass with telemetry
263
+ with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
264
+ logits, telemetry = self.model(batch)
265
+
266
+ # Compute loss (next bit prediction)
267
+ if logits.dim() == 3: # (batch, seq, 2)
268
+ targets = batch[:, 1:] # Next bit prediction
269
+ logits = logits[:, :-1] # Remove last prediction
270
+ loss = F.cross_entropy(logits.reshape(-1, 2), targets.reshape(-1))
271
+ else:
272
+ loss = F.cross_entropy(logits, batch)
273
+
274
+ # Add telemetry regularization (safety metrics)
275
+ if self.model.lambda_K > 0 and 'negentropy_logits' in telemetry:
276
+ k_term = self.model.lambda_K * (1 - telemetry['negentropy_logits'])
277
+ if k_term.dim() == 0: # scalar
278
+ loss = loss + k_term
279
+ else:
280
+ loss = loss + k_term.mean()
281
+ if self.model.lambda_C > 0 and 'lz_complexity_logits' in telemetry:
282
+ c_term = self.model.lambda_C * (1 - telemetry['lz_complexity_logits'])
283
+ if c_term.dim() == 0: # scalar
284
+ loss = loss + c_term
285
+ else:
286
+ loss = loss + c_term.mean()
287
+ if self.model.lambda_S > 0 and 'symbiosis_score' in telemetry:
288
+ s_term = self.model.lambda_S * (1 - telemetry['symbiosis_score'])
289
+ if s_term.dim() == 0: # scalar
290
+ loss = loss + s_term
291
+ else:
292
+ loss = loss + s_term.mean()
293
+
294
+ # Backward pass
295
+ loss.backward()
296
+
297
+ # Gradient clipping
298
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['max_grad_norm'])
299
+
300
+ # Optimizer step
301
+ self.optimizer.step()
302
+ if self.scheduler:
303
+ self.scheduler.step()
304
+
305
+ self.total_steps += 1
306
+
307
+ return {
308
+ 'loss': loss.item(),
309
+ 'K': telemetry.get('negentropy_logits', torch.tensor(0.0)).mean().item() if torch.is_tensor(telemetry.get('negentropy_logits', 0.0)) else telemetry.get('negentropy_logits', 0.0),
310
+ 'C': telemetry.get('lz_complexity_logits', torch.tensor(0.0)).mean().item() if torch.is_tensor(telemetry.get('lz_complexity_logits', 0.0)) else telemetry.get('lz_complexity_logits', 0.0),
311
+ 'S': telemetry.get('symbiosis_score', torch.tensor(0.0)).mean().item() if torch.is_tensor(telemetry.get('symbiosis_score', 0.0)) else telemetry.get('symbiosis_score', 0.0),
312
+ 'lr': self.optimizer.param_groups[0]['lr']
313
+ }
314
+
315
+ def train_epoch(self) -> Dict[str, float]:
316
+ """Train for one epoch."""
317
+ logger.info(f"Starting epoch {self.current_epoch + 1}")
318
+
319
+ # Create data loader
320
+ from torch.utils.data import DataLoader
321
+ dataloader = DataLoader(
322
+ self.dataset,
323
+ batch_size=self.config['batch_size'],
324
+ shuffle=True,
325
+ drop_last=True
326
+ )
327
+
328
+ epoch_losses = []
329
+ epoch_metrics = {'K': [], 'C': [], 'S': []}
330
+
331
+ start_time = time.time()
332
+
333
+ for step, batch in enumerate(dataloader):
334
+ metrics = self.training_step(batch)
335
+
336
+ epoch_losses.append(metrics['loss'])
337
+ epoch_metrics['K'].append(metrics['K'])
338
+ epoch_metrics['C'].append(metrics['C'])
339
+ epoch_metrics['S'].append(metrics['S'])
340
+
341
+ # Log progress
342
+ if step % self.config['log_interval'] == 0:
343
+ logger.info(
344
+ f"Epoch {self.current_epoch + 1}, Step {step}/{len(dataloader)}: "
345
+ f"Loss={metrics['loss']:.6f}, K={metrics['K']:.3f}, "
346
+ f"C={metrics['C']:.3f}, S={metrics['S']:.3f}, LR={metrics['lr']:.2e}"
347
+ )
348
+
349
+ # Calculate epoch metrics
350
+ epoch_time = time.time() - start_time
351
+ avg_loss = sum(epoch_losses) / len(epoch_losses)
352
+ avg_metrics = {k: sum(v) / len(v) for k, v in epoch_metrics.items()}
353
+
354
+ epoch_summary = {
355
+ 'epoch': self.current_epoch + 1,
356
+ 'avg_loss': avg_loss,
357
+ 'time': epoch_time,
358
+ **avg_metrics
359
+ }
360
+
361
+ self.training_history.append(epoch_summary)
362
+
363
+ logger.info(
364
+ f"Epoch {self.current_epoch + 1} completed in {epoch_time:.1f}s: "
365
+ f"Avg Loss={avg_loss:.6f}, K={avg_metrics['K']:.3f}, "
366
+ f"C={avg_metrics['C']:.3f}, S={avg_metrics['S']:.3f}"
367
+ )
368
+
369
+ return epoch_summary
370
+
371
+ def train(self, num_epochs: int):
372
+ """Main training loop."""
373
+ logger.info(f"Starting production training for {num_epochs} epochs...")
374
+ logger.info("Breakthrough configuration: Fixed LR Adafactor + 16M BitTransformerLM")
375
+
376
+ for epoch in range(num_epochs):
377
+ try:
378
+ # Train epoch
379
+ epoch_metrics = self.train_epoch()
380
+ avg_loss = epoch_metrics['avg_loss']
381
+
382
+ # Check if this is the best model
383
+ is_best = avg_loss < self.best_loss
384
+ if is_best:
385
+ self.best_loss = avg_loss
386
+
387
+ # Save checkpoint after each epoch
388
+ self.save_checkpoint(self.current_epoch + 1, avg_loss, is_best)
389
+
390
+ self.current_epoch += 1
391
+
392
+ # Log progress
393
+ logger.info(f"=== EPOCH {self.current_epoch} COMPLETE ===")
394
+ logger.info(f"Loss: {avg_loss:.6f} (best: {self.best_loss:.6f})")
395
+
396
+ # Check for breakthrough performance (loss < 3.0)
397
+ if avg_loss < 3.0:
398
+ logger.info("🚀 BREAKTHROUGH PERFORMANCE ACHIEVED! Loss < 3.0!")
399
+
400
+ except KeyboardInterrupt:
401
+ logger.info("Training interrupted by user")
402
+ break
403
+ except Exception as e:
404
+ logger.error(f"Error in epoch {self.current_epoch + 1}: {e}")
405
+ # Save emergency checkpoint
406
+ try:
407
+ self.save_checkpoint(self.current_epoch, float('inf'), False)
408
+ except Exception:
409
+ pass
410
+ raise
411
+
412
+
413
+ def main():
414
+ """Main function to run production training."""
415
+
416
+ # Production training configuration
417
+ config = {
418
+ # Model parameters (breakthrough configuration)
419
+ 'model_params': {
420
+ 'd_model': 512,
421
+ 'nhead': 16,
422
+ 'num_layers': 8,
423
+ 'dim_feedforward': 1024,
424
+ },
425
+
426
+ # Training parameters
427
+ 'learning_rate': 1e-3, # FIXED LR - key to breakthrough!
428
+ 'weight_decay': 0.01,
429
+ 'batch_size': 4, # Adjust based on memory
430
+ 'sequence_length': 256, # Bit sequence length
431
+ 'num_epochs': 50, # Long training run
432
+ 'max_grad_norm': 1.0,
433
+ 'dropout': 0.1,
434
+ 'total_steps': 10000, # For scheduler
435
+
436
+ # Data parameters
437
+ 'hf_token': None, # Set via environment variable HF_TOKEN
438
+
439
+ # Logging and checkpointing
440
+ 'log_interval': 10,
441
+ 'checkpoint_dir': '/data/BitTransformerLM/checkpoints',
442
+ }
443
+
444
+ # Create trainer
445
+ trainer = ProductionTrainer(config)
446
+
447
+ # Setup components
448
+ trainer.setup_model()
449
+ trainer.setup_optimizer()
450
+ trainer.setup_dataset()
451
+
452
+ # Try to resume from checkpoint
453
+ trainer.load_checkpoint()
454
+
455
+ # Start training
456
+ logger.info("🚀 STARTING BREAKTHROUGH BITTRANSFORMERLM TRAINING!")
457
+ logger.info("Configuration: Fixed LR Adafactor + 16M parameters + CPU training")
458
+
459
+ trainer.train(config['num_epochs'])
460
+
461
+ logger.info("Training completed!")
462
+ logger.info(f"Best loss achieved: {trainer.best_loss:.6f}")
463
+ logger.info(f"Checkpoints saved to: {trainer.checkpoint_dir}")
464
+
465
+
466
+ if __name__ == "__main__":
467
+ main()
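`setup_dataset` in the script above slices each bit sequence into 50%-overlapping windows (stride `seq_len // 2`), keeping only full-length windows. A small sketch of that windowing in plain Python (`overlapping_chunks` is a hypothetical stand-in for the inline loop, not an API of the project):

```python
def overlapping_chunks(bits, seq_len):
    """Windowing used in setup_dataset: stride of seq_len // 2 gives
    50% overlap; windows shorter than seq_len are discarded."""
    stride = seq_len // 2
    chunks = []
    for i in range(0, len(bits) - seq_len + 1, stride):
        chunk = bits[i:i + seq_len]
        if len(chunk) == seq_len:
            chunks.append(chunk)
    return chunks

chunks = overlapping_chunks(list(range(10)), 4)
# windows start at 0, 2, 4, 6 → four windows of length 4
```

The overlap roughly doubles the number of training sequences per source text at the cost of correlated samples, which the shuffled `DataLoader` mitigates.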
scripts/training/production_training.py ADDED
@@ -0,0 +1,462 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ BitTransformerLM Production Training Script
4
+ ==========================================
5
+
6
+ This script implements the breakthrough Fixed RL Adafactor training configuration
7
+ for production-scale BitTransformerLM training with continuous checkpointing.
8
+
9
+ Configuration:
10
+ - Model: 16M parameters (d_model=512, nhead=16, num_layers=8)
11
+ - Optimizer: Fixed LR Adafactor (not auto-LR)
12
+ - Features: Reversible layers, ACT, QAT, compression
13
+ - Data: HuggingFace WCNegentropy/BitTransformerLM dataset
14
+ - Checkpointing: After every training cycle for continuous training
15
+ """
16
+
17
+ import sys
18
+ import os
19
+ import json
20
+ import time
21
+ import logging
22
+ from datetime import datetime
23
+ from pathlib import Path
24
+ from typing import Optional, Dict, Any
25
+
26
+ import torch
27
+ import torch.nn.functional as F
28
+ from datasets import load_dataset
29
+ from huggingface_hub import login
30
+
31
+ # Add paths for imports
32
+ sys.path.append('/data')
33
+ sys.path.append('/data/BitTransformerLM')
34
+
35
+ from bit_transformer import (
36
+ BitTransformerLM,
37
+ text_to_bits,
38
+ bits_to_text,
39
+ save_model,
40
+ load_model,
41
+ set_dropout
42
+ )
43
+ from BTLM_Extensions import configure_adafactor_optimizer
44
+
45
+ # Setup logging
46
+ logging.basicConfig(
47
+ level=logging.INFO,
48
+ format='%(asctime)s - %(levelname)s - %(message)s',
49
+ handlers=[
50
+ logging.FileHandler('production_training.log'),
51
+ logging.StreamHandler()
52
+ ]
53
+ )
54
+ logger = logging.getLogger(__name__)
55
+
56
+ class ProductionTrainer:
57
+ """Production-grade BitTransformerLM trainer with breakthrough configuration."""
58
+
59
+ def __init__(self, config: Dict[str, Any]):
60
+ self.config = config
61
+ self.device = torch.device('cpu') # CPU training as per breakthrough
62
+ self.model = None
63
+ self.optimizer = None
64
+ self.scheduler = None
65
+ self.dataset = None
66
+ self.checkpoint_dir = Path(config['checkpoint_dir'])
67
+ self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
68
+
69
+ # Training state
70
+ self.current_epoch = 0
71
+ self.total_steps = 0
72
+ self.best_loss = float('inf')
73
+ self.training_history = []
74
+
75
+ def setup_model(self):
76
+ """Create the breakthrough 16M parameter BitTransformerLM model."""
77
+ logger.info("Setting up breakthrough BitTransformerLM configuration...")
78
+
79
+ self.model = BitTransformerLM(
80
+ d_model=512, # Breakthrough config
81
+ nhead=16, # 16 attention heads
82
+ num_layers=8, # 8 layers for ~16M params
83
+ dim_feedforward=1024, # 2x d_model for optimal params
84
+ max_seq_len=512, # Reasonable sequence length
85
+ reversible=True, # Memory efficiency
86
+ use_checkpoint=True, # Gradient checkpointing
87
+ use_autocast=True, # CPU mixed precision
88
+ use_act=True, # Adaptive Computation Time
89
+ act_threshold=0.9, # ACT threshold
90
+ lambda_K=0.05, # Safety telemetry weights
91
+ lambda_C=0.05,
92
+ lambda_S=0.05
93
+ ).to(self.device)
94
+
95
+ # Calculate actual parameter count
96
+ total_params = sum(p.numel() for p in self.model.parameters())
97
+ trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
98
+
99
+ logger.info(f"Model created: {total_params:,} total parameters ({trainable_params:,} trainable)")
100
+ logger.info(f"Target: ~16M parameters - {'✓' if 15_000_000 <= total_params <= 17_000_000 else '✗'}")
101
+
102
+ return self.model
103
+
104
+ def setup_optimizer(self):
105
+ """Setup Fixed LR Adafactor optimizer (the breakthrough secret sauce)."""
106
+ logger.info("Setting up Fixed RL Adafactor optimizer...")
107
+
108
+ # CRITICAL: Use fixed LR, not auto-LR (lr=None)
109
+ self.optimizer, self.scheduler = configure_adafactor_optimizer(
110
+ self.model,
111
+ lr=self.config['learning_rate'], # Fixed LR - the key to breakthrough!
112
+ weight_decay=self.config['weight_decay'],
113
+ total_steps=self.config['total_steps']
114
+ )
115
+
116
+ logger.info(f"Fixed LR Adafactor configured with LR={self.config['learning_rate']}")
117
+ return self.optimizer, self.scheduler
118
+
119
+ def setup_dataset(self):
120
+ """Load and prepare the WCNegentropy/BitTransformerLM dataset."""
121
+ logger.info("Loading WCNegentropy/BitTransformerLM dataset...")
122
+
123
+ # Login to HuggingFace
124
+ login(token=self.config['hf_token'] or os.environ.get('HF_TOKEN'))
125
+
126
+ # Load dataset
127
+ try:
128
+ dataset = load_dataset("WCNegentropy/BitTransformerLM")
129
+ logger.info(f"Dataset loaded: {dataset}")
130
+
131
+ # Use train split
132
+ train_data = dataset['train'] if 'train' in dataset else dataset
133
+ logger.info(f"Training samples: {len(train_data)}")
134
+
135
+ # Process dataset - convert to bits using the ACTUAL text_to_bits function
136
+ bit_sequences = []
137
+ for i, sample in enumerate(train_data):
138
+ if i % 1000 == 0:
139
+ logger.info(f"Processing sample {i}/{len(train_data)}")
140
+
141
+ # Try to get text from various fields
142
+ text = None
143
+ if 'original_text' in sample and sample['original_text']:
144
+ text = sample['original_text']
145
+ elif 'text' in sample and sample['text']:
146
+ text = sample['text']
147
+
148
+ if text and text.strip():
149
+ # Use ACTUAL text_to_bits function
150
+ bits = text_to_bits(text)
151
+                     if len(bits) >= self.config['sequence_length']:
+                         bit_sequences.append(bits)
+
+             logger.info(f"Processed {len(bit_sequences)} valid bit sequences")
+
+             # Create training sequences with proper length
+             seq_len = self.config['sequence_length']
+             training_sequences = []
+
+             for bits in bit_sequences:
+                 # Create overlapping chunks
+                 for i in range(0, len(bits) - seq_len + 1, seq_len // 2):
+                     chunk = bits[i:i + seq_len]
+                     if len(chunk) == seq_len:
+                         training_sequences.append(chunk)
+
+             # Convert to tensor with proper dtype
+             self.dataset = torch.tensor(training_sequences, dtype=torch.long)
+             logger.info(f"Created training dataset: {self.dataset.shape}")
+
+         except Exception as e:
+             logger.error(f"Failed to load dataset: {e}")
+             # Fallback to synthetic data for testing
+             logger.info("Falling back to synthetic bit data...")
+             synthetic_bits = torch.randint(0, 2, (1000, self.config['sequence_length']))
+             self.dataset = synthetic_bits
+             logger.warning("Using synthetic data - replace with real dataset for production")
+
+         return self.dataset
+
+     def save_checkpoint(self, epoch: int, loss: float, is_best: bool = False):
+         """Save model checkpoint with all training state."""
+         checkpoint_data = {
+             'epoch': epoch,
+             'total_steps': self.total_steps,
+             'model_state_dict': self.model.state_dict(),
+             'optimizer_state_dict': self.optimizer.state_dict(),
+             'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
+             'loss': loss,
+             'best_loss': self.best_loss,
+             'config': self.config,
+             'training_history': self.training_history,
+             'timestamp': datetime.now().isoformat()
+         }
+
+         # Save latest checkpoint
+         latest_path = self.checkpoint_dir / 'checkpoint_latest.pt'
+         torch.save(checkpoint_data, latest_path)
+         logger.info(f"Saved checkpoint: {latest_path}")
+
+         # Save epoch-specific checkpoint
+         epoch_path = self.checkpoint_dir / f'checkpoint_epoch_{epoch:04d}.pt'
+         torch.save(checkpoint_data, epoch_path)
+
+         # Save best model if this is the best loss
+         if is_best:
+             best_path = self.checkpoint_dir / 'checkpoint_best.pt'
+             torch.save(checkpoint_data, best_path)
+             logger.info(f"NEW BEST MODEL! Loss: {loss:.6f} -> {best_path}")
+
+         # Save training config for reference
+         config_path = self.checkpoint_dir / 'training_config.json'
+         with open(config_path, 'w') as f:
+             json.dump(self.config, f, indent=2)
+
+     def load_checkpoint(self, checkpoint_path: Optional[str] = None) -> bool:
+         """Load checkpoint if available."""
+         if checkpoint_path is None:
+             checkpoint_path = self.checkpoint_dir / 'checkpoint_latest.pt'
+
+         checkpoint_path = Path(checkpoint_path)
+         if not checkpoint_path.exists():
+             logger.info("No checkpoint found - starting fresh training")
+             return False
+
+         logger.info(f"Loading checkpoint: {checkpoint_path}")
+         try:
+             checkpoint = torch.load(checkpoint_path, map_location=self.device)
+
+             self.model.load_state_dict(checkpoint['model_state_dict'])
+             self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+             if self.scheduler and checkpoint['scheduler_state_dict']:
+                 self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+
+             self.current_epoch = checkpoint['epoch']
+             self.total_steps = checkpoint['total_steps']
+             self.best_loss = checkpoint['best_loss']
+             self.training_history = checkpoint.get('training_history', [])
+
+             logger.info(f"Resumed from epoch {self.current_epoch}, best loss: {self.best_loss:.6f}")
+             return True
+
+         except Exception as e:
+             logger.error(f"Failed to load checkpoint: {e}")
+             return False
+
+     def training_step(self, batch: torch.Tensor) -> Dict[str, float]:
+         """Single training step with telemetry."""
+         self.model.train()
+         set_dropout(self.model, self.config['dropout'])
+
+         batch = batch.to(self.device)
+
+         # Zero gradients
+         self.optimizer.zero_grad()
+
+         # Forward pass with telemetry
+         with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
+             logits, telemetry = self.model(batch)
+
+             # Compute loss (next bit prediction)
+             if logits.dim() == 3:  # (batch, seq, 2)
+                 targets = batch[:, 1:]  # Next bit prediction
+                 logits = logits[:, :-1]  # Remove last prediction
+                 loss = F.cross_entropy(logits.reshape(-1, 2), targets.reshape(-1))
+             else:
+                 loss = F.cross_entropy(logits, batch)
+
+             # Add telemetry regularization (safety metrics)
+             if self.model.lambda_K > 0 and 'negentropy_logits' in telemetry:
+                 k_term = self.model.lambda_K * (1 - telemetry['negentropy_logits'])
+                 if k_term.dim() == 0:  # scalar
+                     loss = loss + k_term
+                 else:
+                     loss = loss + k_term.mean()
+             if self.model.lambda_C > 0 and 'lz_complexity_logits' in telemetry:
+                 c_term = self.model.lambda_C * (1 - telemetry['lz_complexity_logits'])
+                 if c_term.dim() == 0:  # scalar
+                     loss = loss + c_term
+                 else:
+                     loss = loss + c_term.mean()
+             if self.model.lambda_S > 0 and 'symbiosis_score' in telemetry:
+                 s_term = self.model.lambda_S * (1 - telemetry['symbiosis_score'])
+                 if s_term.dim() == 0:  # scalar
+                     loss = loss + s_term
+                 else:
+                     loss = loss + s_term.mean()
+
+         # Backward pass
+         loss.backward()
+
+         # Gradient clipping
+         torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['max_grad_norm'])
+
+         # Optimizer step
+         self.optimizer.step()
+         if self.scheduler:
+             self.scheduler.step()
+
+         self.total_steps += 1
+
+         return {
+             'loss': loss.item(),
+             'K': telemetry.get('negentropy_logits', torch.tensor(0.0)).mean().item() if torch.is_tensor(telemetry.get('negentropy_logits', 0.0)) else telemetry.get('negentropy_logits', 0.0),
+             'C': telemetry.get('lz_complexity_logits', torch.tensor(0.0)).mean().item() if torch.is_tensor(telemetry.get('lz_complexity_logits', 0.0)) else telemetry.get('lz_complexity_logits', 0.0),
+             'S': telemetry.get('symbiosis_score', torch.tensor(0.0)).mean().item() if torch.is_tensor(telemetry.get('symbiosis_score', 0.0)) else telemetry.get('symbiosis_score', 0.0),
+             'lr': self.optimizer.param_groups[0]['lr']
+         }
+
+     def train_epoch(self) -> Dict[str, float]:
+         """Train for one epoch."""
+         logger.info(f"Starting epoch {self.current_epoch + 1}")
+
+         # Create data loader
+         from torch.utils.data import DataLoader
+         dataloader = DataLoader(
+             self.dataset,
+             batch_size=self.config['batch_size'],
+             shuffle=True,
+             drop_last=True
+         )
+
+         epoch_losses = []
+         epoch_metrics = {'K': [], 'C': [], 'S': []}
+
+         start_time = time.time()
+
+         for step, batch in enumerate(dataloader):
+             metrics = self.training_step(batch)
+
+             epoch_losses.append(metrics['loss'])
+             epoch_metrics['K'].append(metrics['K'])
+             epoch_metrics['C'].append(metrics['C'])
+             epoch_metrics['S'].append(metrics['S'])
+
+             # Log progress
+             if step % self.config['log_interval'] == 0:
+                 logger.info(
+                     f"Epoch {self.current_epoch + 1}, Step {step}/{len(dataloader)}: "
+                     f"Loss={metrics['loss']:.6f}, K={metrics['K']:.3f}, "
+                     f"C={metrics['C']:.3f}, S={metrics['S']:.3f}, LR={metrics['lr']:.2e}"
+                 )
+
+         # Calculate epoch metrics
+         epoch_time = time.time() - start_time
+         avg_loss = sum(epoch_losses) / len(epoch_losses)
+         avg_metrics = {k: sum(v) / len(v) for k, v in epoch_metrics.items()}
+
+         epoch_summary = {
+             'epoch': self.current_epoch + 1,
+             'avg_loss': avg_loss,
+             'time': epoch_time,
+             **avg_metrics
+         }
+
+         self.training_history.append(epoch_summary)
+
+         logger.info(
+             f"Epoch {self.current_epoch + 1} completed in {epoch_time:.1f}s: "
+             f"Avg Loss={avg_loss:.6f}, K={avg_metrics['K']:.3f}, "
+             f"C={avg_metrics['C']:.3f}, S={avg_metrics['S']:.3f}"
+         )
+
+         return epoch_summary
+
+     def train(self, num_epochs: int):
+         """Main training loop."""
+         logger.info(f"Starting production training for {num_epochs} epochs...")
+         logger.info("Breakthrough configuration: Fixed LR Adafactor + 16M BitTransformerLM")
+
+         for epoch in range(num_epochs):
+             try:
+                 # Train epoch
+                 epoch_metrics = self.train_epoch()
+                 avg_loss = epoch_metrics['avg_loss']
+
+                 # Check if this is the best model
+                 is_best = avg_loss < self.best_loss
+                 if is_best:
+                     self.best_loss = avg_loss
+
+                 # Save checkpoint after each epoch
+                 self.save_checkpoint(self.current_epoch + 1, avg_loss, is_best)
+
+                 self.current_epoch += 1
+
+                 # Log progress
+                 logger.info(f"=== EPOCH {self.current_epoch} COMPLETE ===")
+                 logger.info(f"Loss: {avg_loss:.6f} (best: {self.best_loss:.6f})")
+
+                 # Check for breakthrough performance (loss < 3.0)
+                 if avg_loss < 3.0:
+                     logger.info("🚀 BREAKTHROUGH PERFORMANCE ACHIEVED! Loss < 3.0!")
+
+             except KeyboardInterrupt:
+                 logger.info("Training interrupted by user")
+                 break
+             except Exception as e:
+                 logger.error(f"Error in epoch {self.current_epoch + 1}: {e}")
+                 # Save emergency checkpoint
+                 try:
+                     self.save_checkpoint(self.current_epoch, float('inf'), False)
+                 except Exception:
+                     pass
+                 raise
+
+
+ def main():
+     """Main function to run production training."""
+
+     # Production training configuration
+     config = {
+         # Model parameters (breakthrough configuration)
+         'model_params': {
+             'd_model': 512,
+             'nhead': 16,
+             'num_layers': 8,
+             'dim_feedforward': 1024,
+         },
+
+         # Training parameters
+         'learning_rate': 1e-3,  # FIXED LR - key to breakthrough!
+         'weight_decay': 0.01,
+         'batch_size': 4,  # Adjust based on memory
+         'sequence_length': 256,  # Bit sequence length
+         'num_epochs': 50,  # Long training run
+         'max_grad_norm': 1.0,
+         'dropout': 0.1,
+         'total_steps': 10000,  # For scheduler
+
+         # Data parameters
+         'hf_token': None,  # Set via environment variable HF_TOKEN
+
+         # Logging and checkpointing
+         'log_interval': 10,
+         'checkpoint_dir': '/data/BitTransformerLM/checkpoints',
+     }
+
+     # Create trainer
+     trainer = ProductionTrainer(config)
+
+     # Setup components
+     trainer.setup_model()
+     trainer.setup_optimizer()
+     trainer.setup_dataset()
+
+     # Try to resume from checkpoint
+     trainer.load_checkpoint()
+
+     # Start training
+     logger.info("🚀 STARTING BREAKTHROUGH BITTRANSFORMERLM TRAINING!")
+     logger.info("Configuration: Fixed LR Adafactor + 16M parameters + CPU training")
+
+     trainer.train(config['num_epochs'])
+
+     logger.info("Training completed!")
+     logger.info(f"Best loss achieved: {trainer.best_loss:.6f}")
+     logger.info(f"Checkpoints saved to: {trainer.checkpoint_dir}")
+
+
+ if __name__ == "__main__":
+     main()
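The overlapping-chunk logic in the dataset preparation added above can be sanity-checked in isolation. A minimal pure-Python sketch (`make_chunks` is a hypothetical stand-in for the inline loop, which strides by half the sequence length and drops short tails):

```python
def make_chunks(bits, seq_len):
    """Split a bit list into fixed-length chunks with 50% overlap.

    Hypothetical helper mirroring the dataset-preparation loop in this
    commit: windows start every seq_len // 2 positions, and the length
    guard discards any short tail."""
    chunks = []
    for i in range(0, len(bits) - seq_len + 1, seq_len // 2):
        chunk = bits[i:i + seq_len]
        if len(chunk) == seq_len:
            chunks.append(chunk)
    return chunks

# Ten "bits" with seq_len=4 -> windows starting at offsets 0, 2, 4 and 6.
print(make_chunks(list(range(10)), 4))
```

Every interior position therefore appears in two training windows, which doubles sequence coverage without duplicating whole sequences.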
sync_to_hf.py DELETED
@@ -1,220 +0,0 @@
- #!/usr/bin/env python3
- """
- Sync BitTransformerLM repository to HuggingFace Hub for OS launch.
- Uploads all cleaned documentation and code with proper commit message.
- """
-
- import os
- import logging
- from pathlib import Path
- from huggingface_hub import HfApi, login
- from typing import Optional, List
-
- # Setup logging
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
- logger = logging.getLogger(__name__)
-
- def sync_repository_to_hf(
-     repo_id: str = "WCNegentropy/BitTransformerLM",
-     token: Optional[str] = None,
-     commit_message: str = "🚀 OS Launch: Clean documentation and refined licensing"
- ):
-     """
-     Sync the entire cleaned BitTransformerLM repository to HuggingFace Hub.
-
-     Args:
-         repo_id: HuggingFace repository ID
-         token: HF token (defaults to HF_TOKEN environment variable)
-         commit_message: Commit message for the upload
-     """
-
-     # Get token from environment if not provided
-     if token is None:
-         token = os.environ.get('HF_TOKEN')
-         if not token:
-             logger.error("HF_TOKEN environment variable not set and no token provided")
-             return False
-
-     try:
-         # Login to HuggingFace
-         login(token=token)
-         api = HfApi()
-         logger.info("Successfully authenticated with HuggingFace Hub")
-
-         # Get the repository root directory
-         repo_root = Path(__file__).parent
-         logger.info(f"Repository root: {repo_root}")
-
-         # Files and directories to upload (excluding unnecessary files)
-         include_patterns = [
-             # Core code
-             "bit_transformer/**/*.py",
-             "tests/**/*.py",
-             "*.py",  # Root level Python files
-
-             # Documentation (cleaned)
-             "README.md",
-             "MODEL_CARD.md",
-             "RESEARCH_STATUS.md",
-             "EMPIRICAL_VALIDATION.md",
-             "OPEN_SOURCE_LAUNCH.md",
-             "AGENTS.md",
-
-             # Configuration
-             "requirements.txt",
-             "pyproject.toml",
-             "Dockerfile",
-             "start.sh",
-
-             # License files (cleaned)
-             "LICENSE/**/*.txt",
-         ]
-
-         # Files to exclude
-         exclude_patterns = [
-             "__pycache__/**",
-             "*.pyc",
-             ".git/**",
-             ".pytest_cache/**",
-             "weights/**",
-             "checkpoints/**",
-             "*.log",
-             # Outdated documentation
-             "BitTransformerLM_full_assessment.md",
-             "FORENSIC_*.md",
-             "state_of_the_repo_audit.md",
-             # Old upload script
-             "upload_to_hf.py",
-         ]
-
-         # Get all files to upload
-         files_to_upload = []
-         for pattern in include_patterns:
-             for file_path in repo_root.glob(pattern):
-                 if file_path.is_file():
-                     # Check if file should be excluded
-                     relative_path = file_path.relative_to(repo_root)
-                     should_exclude = any(
-                         relative_path.match(exclude)
-                         for exclude in exclude_patterns
-                     )
-                     if not should_exclude:
-                         files_to_upload.append(file_path)
-
-         logger.info(f"Found {len(files_to_upload)} files to upload")
-
-         # Upload files in batches
-         uploaded_count = 0
-         for file_path in files_to_upload:
-             try:
-                 relative_path = file_path.relative_to(repo_root)
-                 logger.info(f"Uploading: {relative_path}")
-
-                 api.upload_file(
-                     path_or_fileobj=str(file_path),
-                     path_in_repo=str(relative_path),
-                     repo_id=repo_id,
-                     repo_type="model",
-                     commit_message=commit_message,
-                     commit_description="""
- This OS launch commit includes:
-
- ✅ **Cleaned Documentation**
- - Removed inflated claims and marketing language
- - Added honest research status and limitations
- - Created professional model card and validation reports
- - Streamlined licensing to AGPLv3 + commercial contact
-
- ✅ **Refined Codebase**
- - Complete experimental bit-native transformer implementation
- - 57 Python files with comprehensive research framework
- - Safety telemetry and monitoring systems
- - Distributed training and development tools
-
- ✅ **Professional Standards**
- - Empirical validation of all claims
- - Clear experimental vs production distinctions
- - Rigorous research methodology requirements
- - Community contribution framework
-
- Ready for serious research evaluation and academic investigation.
- """.strip()
-                 )
-
-                 uploaded_count += 1
-                 if uploaded_count % 10 == 0:
-                     logger.info(f"Progress: {uploaded_count}/{len(files_to_upload)} files uploaded")
-
-             except Exception as e:
-                 logger.warning(f"Failed to upload {relative_path}: {e}")
-                 continue
-
-         logger.info(f"✅ Successfully uploaded {uploaded_count}/{len(files_to_upload)} files")
-         logger.info(f"🎉 Repository synced to: https://huggingface.co/{repo_id}")
-
-         return True
-
-     except Exception as e:
-         logger.error(f"❌ Failed to sync repository: {e}")
-         return False
-
- def create_release_info():
-     """Create a release information file for the OS launch."""
-     release_info = """# BitTransformerLM v0.1.0 - Experimental Research Release
-
- **Release Date:** August 2025
- **Status:** Open Source Research Implementation
- **License:** AGPLv3 + Commercial Licensing Available
-
- ## What's Included
-
- This release provides a complete experimental framework for bit-native language modeling research:
-
- - **Core Architecture:** 57 Python files implementing bit-native transformer with reversible layers
- - **Safety Systems:** Real-time K/C/S telemetry and monitoring
- - **Research Tools:** Interactive dashboard, distributed training, comprehensive testing
- - **Documentation:** Professional model card, research status, and validation reports
-
- ## Important Notes
-
- ⚠️ **Experimental Status:** This is research code requiring rigorous baseline validation
- ⚠️ **Not Production Ready:** Needs extensive evaluation vs standard transformers
- ⚠️ **Research Use Only:** Intended for academic investigation and experimentation
-
- ## Licensing
-
- - **Open Source:** AGPLv3 for research and open source use
- - **Commercial:** Contact contact@wcnegentropy.com for commercial licensing
-
- ## Next Steps
-
- The research community is invited to:
- 1. Conduct rigorous baseline comparisons vs standard transformers
- 2. Evaluate on established language modeling benchmarks
- 3. Validate (or refute) claimed memory efficiency benefits
- 4. Share findings openly to advance the field
-
- **Research responsibly. Validate rigorously. Share openly.**
- """
-
-     release_file = Path(__file__).parent / "RELEASE_INFO.md"
-     with open(release_file, 'w') as f:
-         f.write(release_info)
-
-     logger.info("Created RELEASE_INFO.md")
-     return release_file
-
- if __name__ == "__main__":
-     # Create release info file
-     create_release_info()
-
-     # Sync to HuggingFace
-     success = sync_repository_to_hf()
-
-     if success:
-         print("\n🚀 BitTransformerLM OS Launch Sync Complete!")
-         print("📍 Repository: https://huggingface.co/WCNegentropy/BitTransformerLM")
-         print("📧 Commercial inquiries: contact@wcnegentropy.com")
-         print("\nReady for research community evaluation! 🧪✨")
-     else:
-         print("\n❌ Sync failed. Please check logs and try again.")
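The deleted script's include/exclude filtering hinged on matching repo-relative paths against glob patterns via `Path.match`. The same idea can be sketched standalone with the standard library's `fnmatch` (`should_exclude` is a hypothetical helper, not part of the original script, and `fnmatch` semantics differ slightly from `Path.match`):

```python
import fnmatch

# Subset of the exclude globs used in the deleted sync script.
EXCLUDE_PATTERNS = ["__pycache__/**", "*.pyc", "weights/**", "*.log"]

def should_exclude(rel_path, patterns=EXCLUDE_PATTERNS):
    """Return True if a repo-relative POSIX path matches any exclude glob."""
    return any(fnmatch.fnmatch(rel_path, pattern) for pattern in patterns)

print(should_exclude("weights/model.pt"))          # True
print(should_exclude("bit_transformer/model.py"))  # False
```

Note that `pathlib.Path.match` matches patterns from the right and treats `*` as a single path component, while `fnmatch` wildcards can cross `/`; the two filters therefore accept slightly different sets, and this sketch only illustrates the shape of the check.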
tests/TEST_RESULTS.md DELETED
@@ -1,578 +0,0 @@
- # Test Results
-
- ## Automated Tests
- - `pytest -q`: all tests passed.
-
- ```
- .... [100%]
- 4 passed in 5.28s
- ```
-
- ## Example Script
- - `python example.py` executed successfully:
-
- ```
- Training loss: 0.8508605360984802
- Available telemetry: ['activations', 'attention_maps', 'entropy', 'negentropy', 'lz_complexity', 'symbiosis_score']
- ```
-
- ## Progressive Scale-Up
- - `python progressive_scaleup.py` (default steps=2) produced:
-
- ```
- Step 0 validation loss: 0.7001
- Step 1 validation loss: 0.6954
- ```
-
- ## Text Inference
- - Running `infer_text` on a short string returned the input text without errors:
-
- ```
- hi
- ```
-
- ## Extended Scaling Test
- Installed torch and ran `python progressive_scaleup.py --steps 4`:
-
- ```
- Step 0 validation loss: 0.6970
- Step 1 validation loss: 0.6915
- Step 2 validation loss: 0.7022
- Step 3 validation loss: 0.7123
- ```
-
- ## Collapse Test
- Running a minimal `collapse_submodel` example produced a 2-layer model without errors:
-
- ```
- collapsed_layers 2
- ```
-
-
- ## Stress Test 2025
- - `pip install -r requirements.txt` succeeded.
- - `pytest -q` reported:
- ```
- 10 passed, 1 skipped
- ```
-
- ### Large Scale-Up
- Ran `python progressive_scaleup.py --steps 8 --eps 0.70`:
- ```
- Step 0 validation loss: 0.7053
- Step 1 validation loss: 0.6945
- Scaled model to 2 layers and width 32
- Step 2 validation loss: 0.6953
- Scaled model to 4 layers and width 32
- Step 3 validation loss: 0.6820
- Scaled model to 8 layers and width 32
- Step 4 validation loss: 0.6722
- Scaled model to 16 layers and width 32
- Step 5 validation loss: 0.6664
- Scaled model to 32 layers and width 32
- Step 6 validation loss: 0.6663
- Scaled model to 64 layers and width 32
- Step 7 validation loss: 0.6742
- Scaled model to 128 layers and width 32
- ```
-
- ### Collapse Submodel
- Using `collapse_submodel` with small clusters produced:
- ```
- collapsed_layers 3
- d_model 16
- ```
-
- ## WikiText Benchmark Attempt
- - `pip install -r requirements.txt` succeeded after installing torch 2.7.1+cpu.
- - Attempted to download WikiText2 via `datasets` but network access to the S3 bucket was blocked.
- - Fallback to random data: ran `python progressive_scaleup.py --steps 12 --width-mult 2.0`:
- ```
- Step 7 validation loss: 0.6980
- Scaled model to 1 layers and width 32
- Step 8 validation loss: 0.7022
- Scaled model to 1 layers and width 32
- Step 9 validation loss: 0.7025
- Scaled model to 1 layers and width 32
- Step 10 validation loss: 0.7055
- Scaled model to 1 layers and width 32
- Step 11 validation loss: 0.6976
- Scaled model to 1 layers and width 32
- ```
- - Collapsing a toy cluster produced:
- ```
- collapsed_layers 1
- ```
-
- ## WikiText Benchmark (datasets)
- Using the HuggingFace `datasets` loader with a small subset:
- ```
- Step 0 validation loss: 0.6237
- Scaled model to 2 layers and width 64
- Step 1 validation loss: 0.5894
- Scaled model to 4 layers and width 128
- Step 2 validation loss: 0.5108
- Scaled model to 8 layers and width 256
- Step 3 validation loss: 0.8422
- Collapsed model validation loss: 0.6019973754882812
- ```
-
- ## WikiText Schedule Benchmark
- Installed requirements via pip and ran `python wikitext_schedule.py --steps 10 --max-len 16 --dataset-size 10`:
- ```
- Step 0 validation loss: 0.6686
- Scaled model to 2 layers and width 32
- Step 1 validation loss: 0.6271
- Scaled model to 2 layers and width 64
- Step 2 validation loss: 0.7467
- Scaled model to 4 layers and width 64
- Step 3 validation loss: 0.6571
- Scaled model to 4 layers and width 128
- Step 4 validation loss: 0.7457
- Scaled model to 8 layers and width 128
- Step 5 validation loss: 0.8038
- Scaled model to 8 layers and width 256
- Step 6 validation loss: 2.6579
- Scaled model to 16 layers and width 256
- Step 7 validation loss: 4.0604
- Scaled model to 16 layers and width 512
- Step 8 validation loss: 8.6210
- Scaled model to 32 layers and width 512
- Step 9 validation loss: 6.4301
- Scaled model to 32 layers and width 1024
- Step 10 validation loss: 11.1592
- ```
- Attempting the full 12-step run exceeded memory limits and the process was killed after step 10.
-
- ## Recursive Integration Flow Test
- Installed requirements manually and ran `python recursive_integration_flow.py`. Output:
-
- ```
- warnings.warn(
- /workspace/Test/recursive_integration_flow.py:87: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
- with torch.cpu.amp.autocast(dtype=torch.bfloat16):
- Step 0 validation loss: 1.2578 K=0.105 C=0.328 S=0.329
- Step 1 validation loss: 0.7305 K=0.031 C=0.095 S=0.244
- ⚠️ Step 1 regressed below metric floor. Halting.
- Traceback (most recent call last):
-   File "/workspace/Test/recursive_integration_flow.py", line 119, in <module>
-     recursive_integration_flow()
-   File "/workspace/Test/recursive_integration_flow.py", line 93, in recursive_integration_flow
-     safe_output = hil_safe_inference(
-                   ^^^^^^^^^^^^^^^^^^^
-   File "/workspace/Test/bit_transformer/safety.py", line 24, in hil_safe_inference
-     raise RuntimeError(
- RuntimeError: Safety gate triggered: C=0.603, S=0.248
- ```
-
- New successful run after adjusting metric floors:
-
- ```
- Step 0 validation loss: 0.7461 K=0.038 C=0.084 S=0.246
- Step 1 validation loss: 0.7344 K=0.036 C=0.073 S=0.243
- Step 2 validation loss: 0.7266 K=0.029 C=0.074 S=0.242
- Step 3 validation loss: 0.7656 K=0.054 C=0.093 S=0.245
- Step 4 validation loss: 0.7422 K=0.026 C=0.097 S=0.241
- Compilation skipped: Dynamo is not supported on Python 3.12+
- Safe output bits: [[1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1]]
- ```
- New run with torch-2.7.1+cpu installed from requirements and compile disabled:
- ```
- Step 0 validation loss: 1.8750 K=0.152 C=0.314 S=0.345
- Step 1 validation loss: 1.0625 K=0.305 C=0.101 S=0.302
- Step 2 validation loss: 0.7266 K=0.028 C=0.083 S=0.244
- Step 3 validation loss: 0.7773 K=0.045 C=0.175 S=0.254
- Step 4 validation loss: 0.7539 K=0.031 C=0.122 S=0.245
- Safe output bits: [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0]]
- ```
- Run with pinned dependencies from updated `requirements.txt`:
- ```
- Step 0 validation loss: 2.4531 K=0.195 C=0.287 S=0.346
- Step 1 validation loss: 1.5781 K=0.176 C=0.307 S=0.340
- Step 2 validation loss: 0.7383 K=0.037 C=0.112 S=0.245
- Step 3 validation loss: 0.7773 K=0.038 C=0.178 S=0.251
- Step 4 validation loss: 0.7227 K=0.028 C=0.099 S=0.239
- Safe output bits: [[1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1]]
- ```
-
- ## WikiText Schedule with Compression
- Ran `python wikitext_schedule.py --steps 2 --dataset-size 64` using the new compression-aware training.
-
- ```
- Step 0 validation loss: 0.6969
- Scaled model to 2 layers and width 32
- Step 1 validation loss: 0.6840
- Scaled model to 2 layers and width 64
- Step 2 validation loss: 0.6746
- ```
- ## WikiText Schedule 10-step Run with Compression
- ```
- Step 0 validation loss: 2.1250
- Scaled model to 2 layers and width 32
- Step 1 validation loss: 2.2188
- Scaled model to 2 layers and width 64
- Step 2 validation loss: 6.0000
- Scaled model to 4 layers and width 64
- Step 3 validation loss: 6.3750
- Scaled model to 4 layers and width 128
- Step 4 validation loss: 4.7812
- Scaled model to 8 layers and width 128
- Step 5 validation loss: 3.8594
- Scaled model to 8 layers and width 256
- Step 6 validation loss: 7.2812
- Scaled model to 16 layers and width 256
- Step 7 validation loss: 9.8125
- Scaled model to 16 layers and width 512
- Step 8 validation loss: 34.5000
- Scaled model to 32 layers and width 512
- Step 9 validation loss: 39.7500
- Scaled model to 32 layers and width 1024
- Step 10 validation loss: 163.0000
- ```
-
- ### 10-step Run with ACT Enabled
- Attempted to rerun the 10-step schedule with `use_act=True` and dataset size 128.
- Training was interrupted due to time limits after step 8. Partial results:
- ```
- Step 0 validation loss: 1.8594
- Scaled model to 2 layers and width 32
- Step 1 validation loss: 0.7344
- Scaled model to 2 layers and width 64
- Step 2 validation loss: 0.5469
- Scaled model to 4 layers and width 64
- Step 3 validation loss: 0.2520
- Scaled model to 4 layers and width 128
- Step 4 validation loss: 0.1748
- Scaled model to 8 layers and width 128
- Step 5 validation loss: 0.0284
- Scaled model to 8 layers and width 256
- Step 6 validation loss: 0.1982
- Scaled model to 16 layers and width 256
- Step 7 validation loss: 0.1562
- Scaled model to 16 layers and width 512
- Step 8 validation loss: 0.2168
- Scaled model to 32 layers and width 512
- ```
-
- ## WikiText-103 100MB Attempt
- Attempted to run training with 100MB of WikiText-103 data streamed via `datasets` and converted to bits. Converting the dataset (352k lines) took too long and the process was interrupted before the first training step could complete.
-
-
- ## Offline Full Bits Training Attempt
- - Installed requirements successfully.
- - Built `full_bits.pt` (100MB WikiText-103 compressed to bits).
- - Ran `python full_bits_train.py` but the training loop was extremely slow and was manually interrupted before completing a single pass.
-
- ## BitSeq Dataset Training
- - Built `full_bits.pt` from WikiText2 using `build_full_bits.py`.
- - Ran `python full_bits_train.py` with BitSeq DataLoader (seq=2048, batch=8).
- - The script loaded one batch and reported `Batch loss: 2.4375`.
-
- ## Offline train_full_sequence Scale-Up (8 steps)
- - Built dataset with `python build_full_bits.py` (~84MB).
- - Trained using `BitTransformerLM.train_full_sequence` over the first 65k bits with ctx_bits=64.
- ```
- Step 0 train loss: 3.7605
- Step 1 train loss: 3.7545
- Step 2 train loss: 3.7434
- Step 3 train loss: 3.7382
- Step 4 train loss: 3.7301
- Step 5 train loss: 3.7261
- Step 6 train loss: 3.7202
- Step 7 train loss: 3.7060
- ```
-
- ## Progressive Scale-Up 8-Step Run
- ```
- Step 0 validation loss: 0.7042
- Step 1 validation loss: 0.7036
- Step 2 validation loss: 0.7061
- Step 3 validation loss: 0.6997
- Step 4 validation loss: 0.7072
- Step 5 validation loss: 0.6892
- Step 6 validation loss: 0.7085
- Step 7 validation loss: 0.6966
- ```
-
- ## Compression Inference Test
- Installed requirements and ran `python wikitext_schedule.py --steps 2 --dataset-size 64`:
- ```
- Step 0 validation loss: 0.9297
- Scaled model to 2 layers and width 32
- Step 1 validation loss: 0.7773
- Scaled model to 2 layers and width 64
- Step 2 validation loss: 0.7773
- ```
-
- Ran a minimal training cycle with compression and generated text from the model:
- ```
- Model output: hllo world
- ```
-
-
- ## Bigger Batch Smoke Test
- Executed `python unified_workflow.py --steps 9 --dataset-size 100` after adding warm-up optimisation. Final lines:
- ```
- Epoch 1 raw_loss=0.5525 acc=0.692 | compressed_loss=0.5449 acc=0.718 direct_loss=0.0000 ratio=1.07
- Step 8 validation loss: 0.4727 K=0.248 C=0.126 S=0.309
- Final validation loss: 0.4824 K=0.245 C=0.131 S=0.308
- Safety gate triggered Safety gate triggered: C=0.476, S=0.292
- Collapsed model validation loss: 0.6928360462188721
- ```
-
- ### Inference Conversation
- ```
- User: hi
- Model: hi
- User: ok
- Model: ok
- ```
-
- ## Bigger Training Smoke Test
-
- Executed `python unified_workflow.py --steps 7 --dataset-size 64` after updating
- the training loop with extra optimizer steps. Final lines:
-
- ```
- Step 6 validation loss: 0.4922 K=0.252 C=0.118 S=0.306
- Final validation loss: 0.4785 K=0.264 C=0.105 S=0.307
- Safety gate triggered Safety gate triggered: C=0.476, S=0.297
- Collapsed model validation loss: 0.6666421890258789
- Workflow results: [(0, 1.015625, 0.2431640625, 0.126953125, 0.30909082293510437), (1, 0.74609375, 0.04248046875, 0.0306396484375, 0.2524452209472656), (2, 0.66796875, 0.11181640625, 0.06396484375, 0.2690799832344055), (3, 0.734375, 0.095703125, 0.044189453125, 0.2644684910774231), (4, 0.5546875, 0.220703125, 0.08837890625, 0.29613998532295227), (5, 0.73046875, 0.03759765625, 0.0654296875, 0.25516262650489807), (6, 0.4921875, 0.251953125, 0.11767578125, 0.30603474378585815), (7, 0.478515625, 0.263671875, 0.10498046875, 0.3072776794433594)]
- ```
-
- ### Inference Conversation (temperature=0.9, top-p=0.95)
-
- ```
- User: hi
- Model: hi
- User: how are you?
- Model: how are you?
- ```
-
- ## Continuous Training Test
- Loaded existing weights when present.
- Performed 2 scaling steps and 1 plateau step on a 16-sample dataset.
- Final validation loss: 0.7383 with the collapsed model at 0.6924.
-
- ## Diffusion LM Smoke Test
- Installed requirements and ran `python unified_workflow.py --steps 2 --dataset-size 32 --max-len 32 --diffusion`:
- ```
- Epoch 0 raw_loss=4.7188 acc=0.188 | compressed_loss=0.0000 acc=0.000 direct_loss=0.0000 ratio=0.00
- Epoch 1 raw_loss=4.6094 acc=0.185 | compressed_loss=0.0000 acc=0.000 direct_loss=0.0000 ratio=0.00
- Step 0 validation loss: 3.9844 K=0.311 C=0.109 S=0.351
- Epoch 0 raw_loss=3.6445 acc=0.355 | compressed_loss=0.0000 acc=0.000 direct_loss=0.0000 ratio=0.00
- Epoch 1 raw_loss=2.4531 acc=0.544 | compressed_loss=0.0000 acc=0.000 direct_loss=0.0000 ratio=0.00
- Step 1 validation loss: 3.2656 K=0.371 C=0.088 S=0.357
- Final validation loss: 3.2344 K=0.373 C=0.087 S=0.357
- Diffusion sample: [1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0]
- Diffusion inference output bits: [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- ```
-
- ## Rigorous Training Regime
- Ran `python tests/rigorous_training_regime.py`:
-
- ```
- ### Progressive Scale-Up (causal=True)
-
- Step 0 validation loss: 0.7167
- Scaled model to 1 layers and width 32
- Step 1 validation loss: 0.6880
- Scaled model to 1 layers and width 32
- Step 2 validation loss: 0.7019
- Scaled model to 1 layers and width 32
- Duration: 0.23s
-
- ### Progressive Scale-Up (causal=False)
-
- Step 0 validation loss: 0.8581
- Scaled model to 1 layers and width 32
- Step 1 validation loss: 0.7439
- Scaled model to 1 layers and width 32
- Step 2 validation loss: 0.7068
- Scaled model to 1 layers and width 32
- Duration: 0.21s
-
- ### Unified Workflow (causal=True)
-
- Loaded model from weights/model.pt.gz
- Epoch 0 raw_loss=0.6719 acc=0.581 | compressed_loss=0.6875 acc=0.586 direct_loss=0.0000 ratio=1.09
- Step 0 validation loss: 0.6367 K=0.091 C=0.069 S=0.284
- Epoch 0 raw_loss=0.6328 acc=0.605 | compressed_loss=0.6328 acc=0.612 direct_loss=0.0000 ratio=1.09
- Step 1 validation loss: 0.6914 K=0.202 C=0.049 S=0.305
- Epoch 0 raw_loss=0.5312 acc=0.718 | compressed_loss=0.6445 acc=0.628 direct_loss=0.0000 ratio=1.09
- Plateau 0 validation loss: 0.5469 K=0.096 C=0.118 S=0.290
- Final validation loss: 0.5430 K=0.099 C=0.104 S=0.289
- Safety gate triggered Safety gate triggered: C=0.484, S=0.285
- Collapsed model validation loss: 0.8396304845809937
- Workflow results: [(0, 0.63671875, 0.09130859375, 0.0693359375, 0.28369221091270447), (1, 0.69140625, 0.2021484375, 0.049072265625, 0.3053092062473297), (2, 0.546875, 0.09619140625, 0.1181640625, 0.2900315225124359), (3, 0.54296875, 0.09912109375, 0.10400390625, 0.289362370967865)]
407
- Inference on 'hi': [0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1]
408
-
409
- Duration: 8.48s
410
-
411
- ### Unified Workflow (causal=False / Diffusion)
412
-
413
- Loaded model from weights/model.pt.gz
414
- Epoch 0 raw_loss=0.8232 acc=0.391 | compressed_loss=0.0000 acc=0.000 direct_loss=0.0000 ratio=0.00
415
- Step 0 validation loss: 0.9805 K=0.098 C=0.067 S=0.285
416
- Epoch 0 raw_loss=0.7471 acc=0.561 | compressed_loss=0.0000 acc=0.000 direct_loss=0.0000 ratio=0.00
417
- Step 1 validation loss: 1.0547 K=0.134 C=0.091 S=0.294
418
- Epoch 0 raw_loss=0.7520 acc=0.609 | compressed_loss=0.0000 acc=0.000 direct_loss=0.0000 ratio=0.00
419
- Plateau 0 validation loss: 0.2119 K=0.187 C=0.185 S=0.332
420
- Final validation loss: 0.2188 K=0.187 C=0.176 S=0.330
421
- Collapsed model validation loss: 0.6897413730621338
422
- Diffusion sample: [1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1]
423
- Workflow results: [(0, 0.98046875, 0.09765625, 0.06689453125, 0.28478696942329407), (1, 1.0546875, 0.1337890625, 0.0908203125, 0.29406091570854187), (2, 0.2119140625, 0.1865234375, 0.1845703125, 0.33178743720054626), (3, 0.21875, 0.1865234375, 0.17578125, 0.32961323857307434)]
424
- Diffusion inference output bits: [1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1]
425
- Duration: 24.25s
426
- ```
427
-
428
- ## Rigorous Training Regime (2025-08-06)
429
- Ran `python tests/rigorous_training_regime.py`:
430
-
431
- ```
432
- ### Progressive Scale-Up (causal=True)
433
-
434
- Step 0 validation loss: 0.6921
435
- Scaled model to 1 layers and width 32
436
- Step 1 validation loss: 0.7171
437
- Scaled model to 1 layers and width 32
438
- Step 2 validation loss: 0.6914
439
- Scaled model to 1 layers and width 32
440
- Duration: 0.27s
441
-
442
- ### Progressive Scale-Up (causal=False)
443
-
444
- Step 0 validation loss: 0.8465
445
- Scaled model to 1 layers and width 32
446
- Step 1 validation loss: 0.7123
447
- Scaled model to 1 layers and width 32
448
- Step 2 validation loss: 0.7009
449
- Scaled model to 1 layers and width 32
450
- Duration: 0.26s
451
-
452
- ### Unified Workflow (causal=True)
453
-
454
- Epoch 0 raw_loss=1.1094 acc=0.593 | compressed_loss=1.1465 acc=0.599 direct_loss=0.0000 ratio=1.09
455
- Step 0 validation loss: 0.8945 K=0.301 C=0.092 S=0.339
456
- Epoch 0 raw_loss=0.9453 acc=0.601 | compressed_loss=0.9707 acc=0.617 direct_loss=0.0000 ratio=1.09
457
- Step 1 validation loss: 0.9180 K=0.301 C=0.088 S=0.338
458
- Epoch 0 raw_loss=0.8984 acc=0.593 | compressed_loss=0.9590 acc=0.599 direct_loss=0.0000 ratio=1.09
459
- Plateau 0 validation loss: 0.7969 K=0.243 C=0.095 S=0.324
460
- Final validation loss: 0.7930 K=0.244 C=0.094 S=0.324
461
- Safety gate triggered Safety gate triggered: C=0.484, S=0.314
462
- Collapsed model validation loss: 0.6552348732948303
463
- Workflow results: [(0, 0.89453125, 0.30078125, 0.09228515625, 0.33890560269355774), (1, 0.91796875, 0.30078125, 0.08837890625, 0.33844876289367676), (2, 0.796875, 0.2431640625, 0.0947265625, 0.32405367493629456), (3, 0.79296875, 0.244140625, 0.09423828125, 0.32419103384017944)]
464
- Inference on 'hi': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
465
-
466
- Duration: 5.26s
467
-
468
- ### Unified Workflow (causal=False / Diffusion)
469
-
470
- Loaded model from weights/model.pt.gz
471
- Epoch 0 raw_loss=1.2266 acc=0.590 | compressed_loss=0.0000 acc=0.000 direct_loss=0.0000 ratio=0.00
472
- Step 0 validation loss: 0.8359 K=0.165 C=0.032 S=0.296
473
- Epoch 0 raw_loss=0.7617 acc=0.603 | compressed_loss=0.0000 acc=0.000 direct_loss=0.0000 ratio=0.00
474
- Step 1 validation loss: 0.7891 K=0.025 C=0.043 S=0.268
475
- Epoch 0 raw_loss=0.7158 acc=0.553 | compressed_loss=0.0000 acc=0.000 direct_loss=0.0000 ratio=0.00
476
- Plateau 0 validation loss: 0.5391 K=0.113 C=0.056 S=0.287
477
- Final validation loss: 0.5391 K=0.116 C=0.060 S=0.287
478
- Collapsed model validation loss: 0.7268564701080322
479
- Diffusion sample: [1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1]
480
- Workflow results: [(0, 0.8359375, 0.1650390625, 0.0322265625, 0.29598498344421387), (1, 0.7890625, 0.0250244140625, 0.04345703125, 0.26766154170036316), (2, 0.5390625, 0.11328125, 0.05615234375, 0.2867652475833893), (3, 0.5390625, 0.1162109375, 0.06005859375, 0.28735819458961487)]
481
- Diffusion inference output bits: [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0]
482
- Duration: 3.70s
483
- ```
484
-
485
- ## Rigorous Training Regime (2025-08-06 - 10-step alt length/width)
486
- Ran `python tests/rigorous_training_regime.py`:
487
-
488
- ```
489
- ### Progressive Scale-Up (causal=True)
490
-
491
- Step 0 validation loss: 0.4615
492
- Step 1 validation loss: 0.4427
493
- Step 2 validation loss: 0.4282
494
- Step 3 validation loss: 0.4202
495
- Step 4 validation loss: 0.4175
496
- Scaled length; seq_len=128 width=32 params=8674
497
- Step 5 validation loss: 0.5383
498
- Scaled width; seq_len=128 width=64 params=33730
499
- Step 6 validation loss: 0.4334
500
- Step 7 validation loss: 0.4304
501
- Scaled length; seq_len=256 width=64 params=33730
502
- Step 8 validation loss: 0.5085
503
- Scaled width; seq_len=256 width=128 params=132994
504
- Step 9 validation loss: 0.4279
505
- Duration: 38.96s
506
-
507
- ### Progressive Scale-Up (causal=False)
508
-
509
- Step 0 validation loss: 0.4292
510
- Step 1 validation loss: 0.4053
511
- Step 2 validation loss: 0.4003
512
- Step 3 validation loss: 0.3997
513
- Scaled length; seq_len=128 width=32 params=8674
514
- Step 4 validation loss: 0.4162
515
- Scaled width; seq_len=128 width=64 params=33730
516
- Step 5 validation loss: 0.4173
517
- Scaled length; seq_len=256 width=64 params=33730
518
- Step 6 validation loss: 0.4160
519
- Scaled width; seq_len=256 width=128 params=132994
520
- Step 7 validation loss: 0.4211
521
- Scaled length; seq_len=512 width=128 params=132994
522
- Step 8 validation loss: 0.4227
523
- Scaled width; seq_len=512 width=256 params=528130
524
- Step 9 validation loss: 0.4146
525
- Duration: 173.71s
526
-
527
- ### Unified Workflow (causal=True)
528
-
529
- Epoch 0 raw_loss=3.1562 acc=0.540 | compressed_loss=3.4531 acc=0.529 direct_loss=0.0000 ratio=1.09
530
- Step 0 validation loss: 2.9688 K=0.559 C=0.220 S=0.475
531
- Epoch 0 raw_loss=2.7188 acc=0.540 | compressed_loss=2.9883 acc=0.529 direct_loss=0.0000 ratio=1.09
532
- Step 1 validation loss: 3.4531 K=0.566 C=0.222 S=0.481
533
- Epoch 0 raw_loss=3.0625 acc=0.540 | compressed_loss=3.4414 acc=0.529 direct_loss=0.0000 ratio=1.09
534
- Plateau 0 validation loss: 3.0781 K=0.559 C=0.219 S=0.474
535
- Final validation loss: 3.0938 K=0.559 C=0.220 S=0.475
536
- Safety gate triggered Safety gate triggered: C=0.484, S=0.466
537
- Collapsed model validation loss: 0.6677278280258179
538
- Workflow results: [(0, 2.96875, 0.55859375, 0.2197265625, 0.4746275246143341), (1, 3.453125, 0.56640625, 0.2216796875, 0.4808752238750458), (2, 3.078125, 0.55859375, 0.21875, 0.47436484694480896), (3, 3.09375, 0.55859375, 0.2197265625, 0.474519282579422)]
539
- Inference on 'hi': [1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1]
540
-
541
- Duration: 2.50s
542
-
543
- ### Unified Workflow (causal=False / Diffusion)
544
-
545
- Loaded model from weights/model.pt.gz
546
- Epoch 0 raw_loss=4.3984 acc=0.271 | compressed_loss=0.0000 acc=0.000 direct_loss=0.0000 ratio=0.00
547
- Step 0 validation loss: 4.9688 K=0.512 C=0.208 S=0.449
548
- Epoch 0 raw_loss=3.5859 acc=0.225 | compressed_loss=0.0000 acc=0.000 direct_loss=0.0000 ratio=0.00
549
- Step 1 validation loss: 4.6562 K=0.477 C=0.200 S=0.428
550
- Epoch 0 raw_loss=3.3008 acc=0.225 | compressed_loss=0.0000 acc=0.000 direct_loss=0.0000 ratio=0.00
551
- Plateau 0 validation loss: 3.5469 K=0.439 C=0.158 S=0.396
552
- Final validation loss: 3.5625 K=0.436 C=0.156 S=0.396
553
- Collapsed model validation loss: 0.6747412085533142
554
- Diffusion sample: [1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1]
555
- Workflow results: [(0, 4.96875, 0.51171875, 0.2080078125, 0.44865939021110535), (1, 4.65625, 0.4765625, 0.2001953125, 0.4284386932849884), (2, 3.546875, 0.439453125, 0.158203125, 0.3957676589488983), (3, 3.5625, 0.435546875, 0.15625, 0.39555999636650085)]
556
- Diffusion inference output bits: [1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1]
557
- Duration: 3.42s
558
- ```
559
-
560
- ## WikiText Training Attempt (2025-09-??)
561
- Attempted minimal training on real WikiText-2 data using `train_loop` with dropout 0.1 and evaluation dropout 0.0. Training failed due to a telemetry shape mismatch:
562
-
563
- ```
564
- RuntimeError: The size of tensor a (4) must match the size of tensor b (64) at non-singleton dimension 1
565
- ```
566
-
567
- As a sanity check, ran `hil_safe_inference` on an untrained model in evaluation mode (dropout=0.0):
568
-
569
- ```
570
- Inference output bits: [[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
571
- ```
572
-
573
- ## WikiText Training Debug (2025-09-??)
574
- Ran a minimal `train_loop` on parity-protected WikiText-2 samples with dropout 0.1:
575
-
576
- ```
577
- Epoch 0 raw_loss=0.6278 acc=0.724 | compressed_loss=0.0000 acc=0.000 direct_loss=0.0000 ratio=0.00
578
- ```
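Every "validation loss" figure in the logs above is a shifted next-bit cross-entropy: logits at position t are scored against the ground-truth bit at t+1 (the deleted scripts compute it as `F.cross_entropy(logits[:, :-1].reshape(-1, 2), target[:, 1:].reshape(-1))`). A dependency-free sketch of that metric, using plain Python lists instead of the repo's torch tensors:

```python
import math


def next_bit_nll(logits: list[tuple[float, float]], bits: list[int]) -> float:
    """Mean negative log-likelihood of predicting bit t+1 from position t.

    `logits` holds one (logit_for_0, logit_for_1) pair per position; `bits` is
    the ground-truth bit sequence. The last position has no target, mirroring
    the logits[:, :-1] / target[:, 1:] shift in the training scripts.
    """
    total = 0.0
    count = 0
    for (l0, l1), target in zip(logits[:-1], bits[1:]):
        # Numerically stable log-softmax over the two classes.
        m = max(l0, l1)
        log_z = m + math.log(math.exp(l0 - m) + math.exp(l1 - m))
        log_p = (l1 if target == 1 else l0) - log_z
        total -= log_p
        count += 1
    return total / count
```

An untrained model with uniform logits scores ln 2 ≈ 0.693 by this metric, which is why the early "validation loss" lines above hover around 0.7.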
unified_workflow.py DELETED
@@ -1,165 +0,0 @@
- import argparse
- import os
- import subprocess
- import sys
- import time
- import torch
- from bit_transformer.utils import load_model
- from bit_transformer.hf_checkpoint import (
-     hf_login,
-     save_checkpoint,
-     download_checkpoint,
- )
- from bit_transformer import diffusion_inference
- from bit_transformer.cli_standards import create_workflow_parser, BitTransformerCLI
-
- from integration_schedule import integration_schedule
-
-
- def _launch_dashboard() -> list[subprocess.Popen]:
-     """Start MCP server and dashboard processes."""
-     server = subprocess.Popen([sys.executable, "mcp_server.py"])
-     time.sleep(2)
-     dash_env = dict(os.environ)
-     dash_env.setdefault("MCP_SERVER_ADDR", "http://127.0.0.1:7000")
-     dashboard = subprocess.Popen(
-         [sys.executable, "-m", "bit_transformer.dashboard_app"],
-         env=dash_env,
-     )
-     return [server, dashboard]
-
-
- def _terminate(procs: list[subprocess.Popen]) -> None:
-     for p in procs:
-         p.terminate()
-         try:
-             p.wait(timeout=5)
-         except Exception:
-             p.kill()
-
-
- def run_workflow(
-     steps: int = 10,
-     max_len: int = 64,
-     dataset_size: int = 128,
-     *,
-     launch_ui: bool = False,
-     weights_path: str = "weights/model.pt.gz",
-     collapsed_path: str = "weights/collapsed.pt.gz",
-     plateau_steps: int = 0,
-     epochs_per_step: int = 2,
-     extra_steps: int = 3,
-     collapse: bool = True,
-     hf_repo: str | None = None,
-     hf_token: str | None = None,
-     diffusion: bool = False,
-     noise_schedule: str = "linear",
-     diffusion_steps: int = 8,
-     diffusion_curriculum: bool = False,
-     use_checkpoint: bool = True,
-     reversible: bool = True,
-     qat: bool = False,
- ) -> tuple:
-     """Run the full integration schedule with optional dashboard.
-
-     If ``qat`` is ``True`` the model undergoes 4-bit quantization-aware training
-     before being converted to quantized weights for safety checks.
-     """
-     procs: list[subprocess.Popen] = []
-     if launch_ui:
-         procs = _launch_dashboard()
-     if hf_repo:
-         hf_login(token=hf_token)
-         if not os.path.exists(weights_path):
-             download_checkpoint(weights_path, repo_id=hf_repo)
-     try:
-         results, collapsed = integration_schedule(
-             steps=steps,
-             max_len=max_len,
-             dataset_size=dataset_size,
-             weights_path=weights_path,
-             plateau_steps=plateau_steps,
-             collapsed_path=collapsed_path,
-             epochs_per_step=epochs_per_step,
-             extra_steps=extra_steps,
-             collapse=collapse,
-             diffusion=diffusion,
-             noise_schedule=noise_schedule,
-             diffusion_steps=diffusion_steps,
-             diffusion_curriculum=diffusion_curriculum,
-             use_checkpoint=use_checkpoint,
-             reversible=reversible,
-             qat=qat,
-         )
-         model = load_model(weights_path)
-         print("Workflow results:", results)
-         if diffusion:
-             sample = diffusion_inference(
-                 model, length=max_len, steps=diffusion_steps, schedule=noise_schedule
-             )
-             print("Diffusion inference output bits:", sample[0].tolist())
-         if hf_repo:
-             save_checkpoint(model, repo_id=hf_repo)
-     finally:
-         if launch_ui:
-             _terminate(procs)
-     return model, collapsed
-
-
- if __name__ == "__main__":
-     # Use standardized CLI parser
-     parser = create_workflow_parser()
-
-     # Add workflow-specific arguments
-     workflow_group = parser.add_argument_group('Workflow Configuration')
-     workflow_group.add_argument("--steps", type=int, default=10,
-                                 help="Number of progressive scale-up steps")
-     workflow_group.add_argument("--plateau-steps", type=int, default=0,
-                                 help="Extra training steps at final size")
-     workflow_group.add_argument("--epochs-per-step", type=int, default=2,
-                                 help="Epochs per training step")
-     workflow_group.add_argument("--extra-steps", type=int, default=3,
-                                 help="Optimizer updates after each epoch")
-     workflow_group.add_argument("--no-collapse", action="store_true",
-                                 help="Skip collapsed model generation")
-     workflow_group.add_argument("--dashboard", action="store_true",
-                                 help="Launch MCP server and dashboard UI")
-
-     # Add advanced optimization arguments
-     opt_group = parser.add_argument_group('Advanced Optimization')
-     opt_group.add_argument("--no-checkpoint", action="store_true",
-                            help="Disable gradient checkpointing (faster but more memory)")
-     opt_group.add_argument("--no-reversible", action="store_true",
-                            help="Use standard transformer blocks instead of reversible layers")
-     opt_group.add_argument("--qat", action="store_true",
-                            help="Enable 4-bit quantization-aware training")
-
-     # Override some defaults for workflow context
-     parser.set_defaults(
-         seq_length=64,  # Use seq-length instead of max-len
-         dataset_size=128,
-         weights_path="weights/model.pt.gz"
-     )
-     args = parser.parse_args()
-
-     run_workflow(
-         args.steps,
-         args.seq_length,  # Standardized name
-         args.dataset_size,
-         launch_ui=args.dashboard,
-         weights_path=args.weights_path,
-         collapsed_path=getattr(args, 'collapsed_path', 'weights/collapsed.pt.gz'),
-         plateau_steps=args.plateau_steps,
-         epochs_per_step=args.epochs_per_step,
-         extra_steps=args.extra_steps,
-         collapse=not args.no_collapse,
-         hf_repo=args.hf_repo,
-         hf_token=args.hf_token,
-         diffusion=args.diffusion_mode,  # Standardized name
-         noise_schedule=args.noise_schedule,
-         diffusion_steps=args.diffusion_steps,
-         diffusion_curriculum=args.diffusion_curriculum,
-         use_checkpoint=not args.no_checkpoint,
-         reversible=not args.no_reversible,
-         qat=args.qat,
-     )
watcher.py DELETED
@@ -1,80 +0,0 @@
- import argparse
- import subprocess
- import sys
- import time
- from pathlib import Path
- from watchdog.events import FileSystemEventHandler
- from watchdog.observers import Observer
-
-
- class RestartOnChange(FileSystemEventHandler):
-     """Restart a subprocess when watched files change."""
-
-     def __init__(self, command: list[str], watch_paths: list[str]) -> None:
-         self.command = command
-         self.watch_paths = [Path(p).resolve() for p in watch_paths]
-         self.process: subprocess.Popen | None = None
-         self.restart()
-
-     def restart(self) -> None:
-         if self.process and self.process.poll() is None:
-             self.process.terminate()
-             try:
-                 self.process.wait(timeout=5)
-             except subprocess.TimeoutExpired:
-                 self.process.kill()
-                 self.process.wait()
-         self.process = subprocess.Popen(self.command)
-
-     def on_any_event(self, event) -> None:  # pragma: no cover - runtime utility
-         if event.is_directory:
-             return
-         path = Path(event.src_path)
-         if path.suffix != ".py":
-             return
-         if any(str(path).startswith(str(p)) for p in self.watch_paths):
-             print(f"[watcher] {path} changed, running tests...")
-             subprocess.run([sys.executable, "-m", "pytest", "-q"])
-             print("[watcher] restarting process...")
-             self.restart()
-
-
- def main() -> None:  # pragma: no cover - CLI entry
-     parser = argparse.ArgumentParser(
-         description="Watch files and restart a command on changes",
-     )
-     parser.add_argument(
-         "--command",
-         nargs="+",
-         default=[sys.executable, "mcp_server.py"],
-         help="Command to run",
-     )
-     parser.add_argument(
-         "--paths",
-         nargs="+",
-         default=["bit_transformer", "mcp_server.py"],
-         help="Paths to watch for changes",
-     )
-     args = parser.parse_args()
-
-     observer = Observer()
-     handler = RestartOnChange(args.command, args.paths)
-     for p in args.paths:
-         observer.schedule(handler, p, recursive=True)
-     observer.start()
-     try:
-         while True:
-             time.sleep(1)
-     except KeyboardInterrupt:
-         pass
-     finally:
-         observer.stop()
-         handler.restart()
-         if handler.process and handler.process.poll() is None:
-             handler.process.terminate()
-             handler.process.wait()
-         observer.join()
-
-
- if __name__ == "__main__":  # pragma: no cover - CLI entry
-     main()
wikitext_benchmark.py DELETED
@@ -1,47 +0,0 @@
- import torch
- import torch.nn.functional as F
- from datasets import load_dataset
- from bit_transformer import text_to_bits, collapse_submodel
- from progressive_scaleup import progressive_scale_up_text
-
-
- def lines_to_bits(lines, max_len=64):
-     data = []
-     for text in lines:
-         bits = text_to_bits(text)[:max_len]
-         if len(bits) < max_len:
-             bits.extend([0] * (max_len - len(bits)))
-         data.append(bits)
-     return data
-
-
- def main():
-     ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")
-     val_ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation[:1%]")
-     train_lines = [item["text"] for item in ds][:256]
-     valid_lines = [item["text"] for item in val_ds][:64]
-
-     train_bits = lines_to_bits(train_lines)
-     valid_bits = lines_to_bits(valid_lines)
-
-     progressive_scale_up_text(
-         eps=0.65,
-         steps=4,
-         width_mult=2.0,
-         max_len=64,
-         dataset_size=min(64, len(train_bits)),
-     )
-
-     target_params = dict(d_model=16, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=64)
-     model, _ = collapse_submodel(train_bits[:64], target_params, max_rounds=1)
-
-     val_tensor = torch.tensor(valid_bits, dtype=torch.long)
-     logits, _ = model(val_tensor)
-     pred = logits[:, :-1, :].reshape(-1, 2)
-     target = val_tensor[:, 1:].reshape(-1)
-     loss = F.cross_entropy(pred, target)
-     print("Collapsed model validation loss:", loss.item())
-
-
- if __name__ == "__main__":
-     main()
wikitext_schedule.py DELETED
@@ -1,130 +0,0 @@
- import numpy as np
- import torch
- import torch.nn.functional as F
- from torch.utils.data import Dataset
- from pathlib import Path
- from datasets import load_dataset
-
- from bit_transformer import (
-     BitTransformerLM,
-     configure_optimizer,
-     expand_model,
-     text_to_bits,
- )
- from bit_transformer.training import train_loop as basic_train
-
-
- def _build_memmap(lines, path: Path, max_len: int) -> None:
-     """Precompute bit tensors into a memory-mapped file."""
-     arr = np.memmap(path, mode="w+", shape=(len(lines), max_len), dtype="uint8")
-     for idx, text in enumerate(lines):
-         bits = text_to_bits(text)[:max_len]
-         if len(bits) < max_len:
-             bits.extend([0] * (max_len - len(bits)))
-         arr[idx] = np.array(bits, dtype="uint8")
-     arr.flush()
-
-
- class MemmapDataset(Dataset):
-     """Dataset backed by a memory-mapped array."""
-
-     def __init__(self, path: Path, length: int, max_len: int) -> None:
-         self.path = path
-         self.length = length
-         self.max_len = max_len
-         self._arr = np.memmap(path, mode="r", shape=(length, max_len), dtype="uint8")
-
-     def __len__(self) -> int:  # pragma: no cover - trivial
-         return self.length
-
-     def __getitem__(self, idx: int) -> torch.Tensor:
-         return torch.from_numpy(self._arr[idx].astype("int64"))
-
-
- def progressive_scale_schedule(steps=12, max_len=64, dataset_size=128):
-     """Run deterministic scale-up on WikiText data."""
-     ds = load_dataset("wikitext", "wikitext-2-raw-v1")
-     train_lines = [t for t in ds["train"]["text"] if t.strip()][:dataset_size]
-     valid_lines = [t for t in ds["validation"]["text"] if t.strip()][: dataset_size // 4]
-
-     train_path = Path("wikitext_train.memmap")
-     valid_path = Path("wikitext_valid.memmap")
-     _build_memmap(train_lines, train_path, max_len)
-     _build_memmap(valid_lines, valid_path, max_len)
-
-     train = MemmapDataset(train_path, len(train_lines), max_len)
-     valid = torch.from_numpy(
-         np.memmap(valid_path, mode="r", shape=(len(valid_lines), max_len), dtype="uint8")
-     ).long()
-
-     layers = 1
-     width = 32
-     params = dict(
-         d_model=width,
-         nhead=4,
-         num_layers=layers,
-         dim_feedforward=width * 2,
-         max_seq_len=max_len,
-         reversible=True,
-         chunk_size=max_len,
-         use_autocast=True,
-         use_act=True,
-         act_threshold=0.9,
-     )
-     model = BitTransformerLM(**params)
-     steps_per_epoch = max(1, (len(train) + 7) // 8)
-     optimizer, scheduler = configure_optimizer(model, lr=1e-3, total_steps=(steps + 1) * steps_per_epoch)
-
-     results = []
-     for step in range(steps + 1):
-         basic_train(
-             model,
-             train,
-             epochs=1,
-             compress_prob=0.5,
-             log=False,
-             forward_kwargs=None,
-             num_workers=2,
-         )
-
-         with torch.no_grad():
-             logits, _ = model(valid)
-             pred = logits[:, :-1, :].reshape(-1, 2)
-             target = valid[:, 1:].reshape(-1)
-             val_loss = F.cross_entropy(pred, target).item()
-         print(f"Step {step} validation loss: {val_loss:.4f}")
-         results.append((step, val_loss))
-
-         if step < steps:
-             if step % 2 == 0:
-                 layers *= 2
-             else:
-                 width *= 2
-             params = dict(
-                 d_model=width,
-                 nhead=4,
-                 num_layers=layers,
-                 dim_feedforward=width * 2,
-                 max_seq_len=max_len,
-                 reversible=True,
-                 chunk_size=max_len,
-                 use_autocast=True,
-                 use_act=True,
-                 act_threshold=0.9,
-             )
-             model = expand_model(model, params)
-             optimizer, scheduler = configure_optimizer(model, lr=1e-3, total_steps=(steps - step) * steps_per_epoch)
-             print(f"Scaled model to {layers} layers and width {width}")
-     return results
-
-
- if __name__ == "__main__":
-     import argparse
-
-     parser = argparse.ArgumentParser(description="Deterministic scale-up benchmark")
-     parser.add_argument("--steps", type=int, default=12, help="number of scale-up steps")
-     parser.add_argument("--max-len", type=int, default=64, help="sequence length")
-     parser.add_argument("--dataset-size", type=int, default=128, help="number of training lines")
-     args = parser.parse_args()
-
-     progressive_scale_schedule(steps=args.steps, max_len=args.max_len, dataset_size=args.dataset_size)
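Both deleted wikitext scripts share the same truncate-or-pad preprocessing before bits reach the model. A minimal self-contained sketch of that pattern; note that `text_to_bits` below is a hypothetical plain 8-bit-ASCII stand-in for `bit_transformer.text_to_bits` (the real helper is parity-protected per the logs, so its exact bit stream differs):

```python
def text_to_bits(text: str) -> list[int]:
    # Hypothetical stand-in: expand each UTF-8 byte into 8 bits, MSB first.
    bits: list[int] = []
    for byte in text.encode("utf-8"):
        bits.extend((byte >> i) & 1 for i in reversed(range(8)))
    return bits


def lines_to_bits(lines: list[str], max_len: int = 64) -> list[list[int]]:
    # Truncate each line's bit sequence to max_len, or zero-pad it up to
    # max_len, so every sample has a fixed length for batching.
    data = []
    for text in lines:
        bits = text_to_bits(text)[:max_len]
        if len(bits) < max_len:
            bits.extend([0] * (max_len - len(bits)))
        data.append(bits)
    return data
```

Fixed-length rows are what let `wikitext_schedule.py` back the dataset with a rectangular `np.memmap` rather than a ragged Python list.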