BitTransformerLM / markov_spline_cli.py
WCNegentropy's picture
🚀 OS Launch: Clean documentation and refined licensing
b9246a0 verified
raw
history blame
11.7 kB
#!/usr/bin/env python3
"""
MarkovSpline CLI Interface for BitTransformerLM Integration
Provides command-line tools for using MarkovSpline data smoothing
with BitTransformerLM training and inference pipelines.
"""
import argparse
import sys
import os
import json
import numpy as np
import torch
from pathlib import Path
from typing import List, Dict, Any, Optional
# Add MarkovSpline to path
sys.path.insert(0, '/data/MarkovSpline')
from bitpipe_integration import MarkovSplineBitPipeModule, create_markov_spline_bitpipe_module
from core import SplineType
# Simple text to bits converter for CLI
class TextToBitsConverter:
    """Simple text to bits converter."""

    def text_to_bits(self, text, max_length=128):
        """Convert text to a bit sequence of exactly ``max_length`` bits.

        The text is encoded as UTF-8 and every byte contributes exactly
        eight bits, most significant bit first. Encoding to bytes (rather
        than formatting ``ord(char)``) keeps the output byte-aligned for
        characters outside the 8-bit range, whose code points would
        otherwise expand to more than eight bits each.

        Args:
            text: The string to convert.
            max_length: Number of bits in the returned sequence. Only the
                first ``max_length // 8`` encoded bytes are used.

        Returns:
            A list of ``0``/``1`` integers, zero-padded on the right or
            truncated so that it holds ``max_length`` bits.
        """
        # 'surrogatepass' keeps lone surrogates encodable, so no input string
        # that could be converted before starts raising here.
        payload = text.encode('utf-8', errors='surrogatepass')[:max_length // 8]
        bit_sequence = [
            (byte >> shift) & 1
            for byte in payload
            for shift in range(7, -1, -1)
        ]

        # Pad or truncate to max_length
        if len(bit_sequence) < max_length:
            bit_sequence.extend([0] * (max_length - len(bit_sequence)))
        else:
            bit_sequence = bit_sequence[:max_length]
        return bit_sequence
class MarkovSplineBitTransformerCLI:
    """CLI interface for MarkovSpline + BitTransformerLM integration.

    Holds a MarkovSpline module (``markov_module``, ``None`` until
    ``initialize_markov_spline`` succeeds) and a text-to-bits converter.
    Methods report problems by printing a message and returning a failure
    value rather than raising.
    """

    def __init__(self):
        self.markov_module = None
        self.text_converter = TextToBitsConverter()

    @staticmethod
    def _json_default(obj: Any) -> Any:
        """Serialize a value that ``json`` cannot encode natively.

        Objects exposing a callable ``tolist()`` (for example numpy arrays
        and scalars, or torch tensors) are converted to plain Python
        values so numeric results stay numeric in the output file.
        Anything else falls back to ``str(obj)``.
        """
        to_list = getattr(obj, 'tolist', None)
        if callable(to_list):
            return to_list()
        return str(obj)

    def initialize_markov_spline(self, config: Optional[Dict] = None) -> bool:
        """Initialize MarkovSpline module with configuration.

        Args:
            config: Optional configuration passed through unchanged to
                ``create_markov_spline_bitpipe_module``.

        Returns:
            ``True`` when the module was created, ``False`` (after printing
            the error) when creation raised.
        """
        try:
            self.markov_module = create_markov_spline_bitpipe_module(config)
            print(f"βœ… Initialized MarkovSpline module: {self.markov_module.module_name}")
            return True
        except Exception as e:
            print(f"❌ Failed to initialize MarkovSpline: {e}")
            return False

    def preprocess_text_data(self,
                             input_file: str,
                             output_file: str,
                             smoothing_strength: float = 0.15,
                             chunk_size: int = 128) -> bool:
        """Preprocess text data using MarkovSpline for BitTransformerLM training.

        Reads ``input_file`` as UTF-8, converts every non-blank line to a
        ``chunk_size``-bit sequence, runs the sequences through the
        module's ``'preprocess_training'`` operation and writes the result
        as JSON to ``output_file``.

        Args:
            input_file: Path of the text file, one sample per line.
            output_file: Path of the JSON file to write.
            smoothing_strength: Forwarded to the ``'data_preprocessor'``
                application.
            chunk_size: Bit length of each converted sample.

        Returns:
            ``True`` on success; ``False`` when the module is not
            initialized, processing reports failure, or any step raises.
        """
        if self.markov_module is None:
            print("❌ MarkovSpline module not initialized")
            return False

        try:
            # Read input text
            with open(input_file, 'r', encoding='utf-8') as f:
                text_data = f.read().strip().split('\n')
            print(f"πŸ“– Processing {len(text_data)} text samples...")

            # Convert text to bit sequences, skipping blank lines
            bit_sequences = [
                self.text_converter.text_to_bits(text, max_length=chunk_size)
                for text in text_data
                if text.strip()
            ]
            print(f"πŸ”„ Converting to bit sequences: {len(bit_sequences)} sequences")

            # Initialize MarkovSpline preprocessor
            self.markov_module.initialize_application('data_preprocessor',
                                                      smoothing_strength=smoothing_strength,
                                                      preserve_features=True)

            # Process bit sequences through MarkovSpline
            result = self.markov_module.process_data(
                bit_sequences,
                'preprocess_training',
                binary_data=True
            )
            if not result.get('success'):
                print(f"❌ Processing failed: {result.get('error', 'Unknown error')}")
                return False

            # Save processed sequences
            processed_data = {
                'processed_sequences': result['processed_sequences'],
                'preprocessing_summary': result['preprocessing_summary'],
                'original_count': len(bit_sequences),
                'smoothing_strength': smoothing_strength,
                'chunk_size': chunk_size
            }
            # Explicit encoding keeps the output independent of the locale;
            # _json_default keeps array results as lists instead of their
            # string representation.
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(processed_data, f, indent=2, default=self._json_default)

            print(f"βœ… Preprocessed data saved to: {output_file}")
            print(f"πŸ“Š Summary: {result['preprocessing_summary']}")
            return True
        except Exception as e:
            print(f"❌ Preprocessing failed: {e}")
            return False

    def smooth_bit_sequence(self,
                            bit_sequence: List[int],
                            smoothing_type: str = 'predict_binary',
                            num_predictions: int = 10) -> Dict[str, Any]:
        """Smooth/predict bit sequence using MarkovSpline.

        Args:
            bit_sequence: Bits handed to the module's ``process_data``.
            smoothing_type: Operation name passed to ``process_data``.
            num_predictions: Forwarded as the ``num_predictions`` keyword.

        Returns:
            Whatever ``process_data`` returns, or a dict with
            ``'success': False`` and an ``'error'`` message when the module
            is not initialized or the call raises.
        """
        if self.markov_module is None:
            print("❌ MarkovSpline module not initialized")
            return {'success': False, 'error': 'Module not initialized'}

        try:
            return self.markov_module.process_data(
                bit_sequence,
                smoothing_type,
                num_predictions=num_predictions
            )
        except Exception as e:
            print(f"❌ Bit sequence processing failed: {e}")
            return {'success': False, 'error': str(e)}

    def smooth_training_gradients(self,
                                  gradient_file: str,
                                  output_file: str,
                                  learning_rate: float = 0.01,
                                  smoothing_strength: float = 0.2) -> bool:
        """Apply MarkovSpline gradient smoothing to BitTransformerLM training.

        Loads a checkpoint that must contain ``'gradients'`` and
        ``'parameters'`` entries, runs the module's ``'smooth_gradients'``
        operation on them and saves the result with ``torch.save``.

        Args:
            gradient_file: Path of the checkpoint to load.
            output_file: Path of the checkpoint to write.
            learning_rate: Forwarded to the ``'gradient_smoother'``
                application.
            smoothing_strength: Forwarded to the ``'gradient_smoother'``
                application.

        Returns:
            ``True`` on success; ``False`` when the module is not
            initialized, the checkpoint lacks the required entries,
            smoothing reports failure, or any step raises.
        """
        if self.markov_module is None:
            print("❌ MarkovSpline module not initialized")
            return False

        try:
            # Load gradient data (assuming PyTorch checkpoint format)
            # SECURITY NOTE(review): torch.load unpickles the file, so only
            # checkpoints from trusted sources should be passed here.
            checkpoint = torch.load(gradient_file, map_location='cpu')
            if 'gradients' not in checkpoint or 'parameters' not in checkpoint:
                print("❌ Invalid gradient file format")
                return False

            # Initialize gradient smoother
            self.markov_module.initialize_application('gradient_smoother',
                                                      learning_rate=learning_rate,
                                                      smoothing_strength=smoothing_strength)

            # Process gradients
            result = self.markov_module.process_data(
                {
                    'parameters': checkpoint['parameters'],
                    'gradients': checkpoint['gradients']
                },
                'smooth_gradients'
            )
            if not result.get('success'):
                print(f"❌ Gradient smoothing failed: {result.get('error', 'Unknown error')}")
                return False

            # Save smoothed parameters
            smoothed_checkpoint = {
                'smoothed_parameters': result['smoothed_parameters'],
                'optimization_metrics': result['optimization_metrics'],
                'original_gradients': checkpoint['gradients']
            }
            torch.save(smoothed_checkpoint, output_file)

            print(f"βœ… Smoothed gradients saved to: {output_file}")
            print(f"πŸ“Š Optimization metrics: {result['optimization_metrics']}")
            return True
        except Exception as e:
            print(f"❌ Gradient smoothing failed: {e}")
            return False

    def create_smoothed_dataset(self,
                                input_dataset: str,
                                output_dataset: str,
                                config: Optional[Dict] = None) -> bool:
        """Create smoothed dataset for BitTransformerLM training.

        Merges ``config`` over the built-in defaults, initializes the
        module with the merged settings if it is not initialized yet, and
        then delegates to ``preprocess_text_data`` with the merged
        ``'smoothing_strength'``.

        Args:
            input_dataset: Path of the input text file.
            output_dataset: Path of the JSON file to write.
            config: Optional overrides for the default settings; the
                caller's dict is not modified.

        Returns:
            ``True`` on success; ``False`` when initialization or
            preprocessing fails.
        """
        # Default configuration for dataset smoothing
        settings = {
            'smoothing_strength': 0.1,
            'num_states': 20,
            'spline_type': 'cubic',
            'preserve_features': True
        }
        if config:
            settings.update(config)

        # Stop early when initialization fails instead of continuing with
        # a module that is known to be missing.
        if self.markov_module is None and not self.initialize_markov_spline(settings):
            return False

        return self.preprocess_text_data(input_dataset, output_dataset,
                                         settings['smoothing_strength'])
def main():
    """Run the MarkovSpline CLI and return a process exit code.

    Parses the command line, optionally loads a JSON configuration file,
    initializes the MarkovSpline module and dispatches to the handler for
    the selected command.

    Returns:
        ``0`` when the selected command succeeded, ``1`` otherwise
        (including configuration-load and initialization failures).
    """
    arg_parser = argparse.ArgumentParser(description='MarkovSpline CLI for BitTransformerLM')
    arg_parser.add_argument(
        'command',
        choices=['preprocess', 'smooth-gradients', 'create-dataset', 'predict-bits'],
        help='Command to execute',
    )

    # Arguments shared by every command
    arg_parser.add_argument('--input', '-i', required=True, help='Input file path')
    arg_parser.add_argument('--output', '-o', required=True, help='Output file path')
    arg_parser.add_argument('--config', '-c', help='Configuration JSON file')

    # Options used by the preprocessing command
    arg_parser.add_argument('--smoothing-strength', type=float, default=0.15,
                            help='Smoothing strength (0.0-1.0)')
    arg_parser.add_argument('--chunk-size', type=int, default=128,
                            help='Text chunk size for bit conversion')

    # Option used by the gradient smoothing command
    arg_parser.add_argument('--learning-rate', type=float, default=0.01,
                            help='Learning rate for gradient smoothing')

    # Option used by the bit prediction command
    arg_parser.add_argument('--num-predictions', type=int, default=10,
                            help='Number of bit predictions to generate')

    opts = arg_parser.parse_args()

    # The configuration file is optional; a load failure aborts the run.
    settings = None
    if opts.config:
        try:
            with open(opts.config, 'r') as config_handle:
                settings = json.load(config_handle)
        except Exception as exc:
            print(f"❌ Failed to load config: {exc}")
            return 1

    cli = MarkovSplineBitTransformerCLI()
    if not cli.initialize_markov_spline(settings):
        return 1

    selected = opts.command

    if selected == 'preprocess':
        finished = cli.preprocess_text_data(
            opts.input, opts.output,
            opts.smoothing_strength, opts.chunk_size
        )
        return 0 if finished else 1

    if selected == 'smooth-gradients':
        finished = cli.smooth_training_gradients(
            opts.input, opts.output,
            opts.learning_rate, opts.smoothing_strength
        )
        return 0 if finished else 1

    if selected == 'create-dataset':
        finished = cli.create_smoothed_dataset(opts.input, opts.output, settings)
        return 0 if finished else 1

    if selected == 'predict-bits':
        finished = False
        try:
            # The input file is JSON whose 'bits' entry holds the sequence.
            with open(opts.input, 'r') as source:
                payload = json.load(source)
            bits = payload.get('bits', [])
            outcome = cli.smooth_bit_sequence(bits, 'predict_binary', opts.num_predictions)
            if not outcome['success']:
                print(f"❌ Bit prediction failed: {outcome.get('error', 'Unknown error')}")
            else:
                with open(opts.output, 'w') as sink:
                    json.dump(outcome, sink, indent=2, default=str)
                print(f"βœ… Bit predictions saved to: {opts.output}")
                finished = True
        except Exception as exc:
            print(f"❌ Bit prediction failed: {exc}")
        return 0 if finished else 1

    # No handler matched the command.
    return 1
# Script entry point: run the CLI and use its return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())