# BitTransformerLM/scripts/testing/full_attention_inference_test.py
# Uploaded by WCNegentropy — commit ba1fad5 (verified):
# "πŸš€ Refined BitTransformerLM: Organized codebase with best practices"
#!/usr/bin/env python3
"""
Full Attention BitTransformerLM Diffusion Inference Test
========================================================
Test the newly trained full bi-directional attention BitTransformerLM model
using denoising diffusion generation to evaluate improvements from full attention training.
Model Configuration:
- Same full bi-directional unchunked attention as training (chunk_size=None)
- Proper eval() mode with dropout management
- Use latest checkpoint_best.pt from full attention training
- Test with same diffusion inference that worked before
"""
import sys
import torch
import torch.nn.functional as F
from datetime import datetime
sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')
from bit_transformer import (
BitTransformerLM,
text_to_bits,
bits_to_text,
diffusion_inference,
set_dropout
)
def load_full_attention_model(checkpoint_path='/data/BitTransformerLM/checkpoints/checkpoint_best.pt'):
    """Load the newly trained full attention BitTransformerLM model.

    Builds the model with the exact hyperparameters used during full
    attention training (so the saved state dict keys and shapes match),
    loads the checkpoint weights, and switches to inference mode.

    Args:
        checkpoint_path: Path to the training checkpoint. Defaults to the
            latest ``checkpoint_best.pt`` from full attention training.

    Returns:
        The loaded ``BitTransformerLM`` in eval mode with dropout disabled.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
        KeyError: If the checkpoint has no ``model_state_dict`` entry.
    """
    print("πŸš€ Loading Full Attention BitTransformerLM for diffusion inference...")
    # Create model with SAME configuration as full attention training.
    model = BitTransformerLM(
        d_model=512,             # Same as training
        nhead=16,                # Same as training
        num_layers=8,            # Same as training
        dim_feedforward=1024,    # Same as training
        max_seq_len=512,         # Same as training
        reversible=True,         # Same as training
        use_checkpoint=False,    # Disable for inference
        use_autocast=False,      # Disable for inference
        use_act=True,            # Same as training
        act_threshold=0.9,       # Same as training
        lambda_K=0.05,           # Same as training
        lambda_C=0.05,           # Same as training
        lambda_S=0.05,           # Same as training
        chunk_size=None,         # FULL ATTENTION - same as training
        overlap=0,               # Same as training
        full_attn_logging=True   # Same as training
    )
    # NOTE(review): torch.load unpickles arbitrary objects — safe only for
    # trusted checkpoints. Consider weights_only=True if torch version allows.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    # Evaluation mode plus explicit dropout shutoff for deterministic inference.
    model.eval()
    set_dropout(model, 0.0)  # Disable dropout for inference
    # Report checkpoint provenance (falls back to 'unknown' for old formats).
    epoch = checkpoint.get('epoch', 'unknown')
    loss = checkpoint.get('loss', 'unknown')
    print(f"βœ… Full Attention Model loaded! Epoch: {epoch}, Loss: {loss}")
    # Parameter count is informational only.
    total_params = sum(p.numel() for p in model.parameters())
    print(f"πŸ“Š Parameters: {total_params:,}")
    return model
def test_basic_diffusion_generation(model):
    """Run unconditional denoising-diffusion generation at several settings.

    Samples bits for three (length, steps, schedule) combinations and
    records whether generation plus bit-to-text decoding succeeded.

    Args:
        model: A BitTransformerLM instance in eval mode.

    Returns:
        A list of per-test result dicts (success flag, output or error).
    """
    print("\nπŸ§ͺ === BASIC FULL ATTENTION DIFFUSION GENERATION ===")
    configs = [
        {"length": 36, "steps": 8, "schedule": "linear"},
        {"length": 45, "steps": 12, "schedule": "cosine"},
        {"length": 54, "steps": 16, "schedule": "exp"},
    ]
    outcomes = []
    for test_num, cfg in enumerate(configs, 1):
        print(f"\n--- Test {test_num}: {cfg['length']//9} chars, {cfg['schedule']} ---")
        try:
            # Sample a bit sequence via denoising diffusion.
            sample = diffusion_inference(
                model,
                length=cfg['length'],
                steps=cfg['steps'],
                batch_size=1,
                schedule=cfg['schedule'],
            )
            raw_bits = sample.squeeze().tolist()
            # Decoding may raise if the bits don't form valid text.
            text = bits_to_text(raw_bits)
            print(f"βœ… SUCCESS: '{text}'")
            record = {
                "test": f"basic_{test_num}",
                "config": cfg,
                "success": True,
                "output": text,
                "bits": len(raw_bits),
            }
        except Exception as exc:
            print(f"❌ FAILED: {exc}")
            record = {
                "test": f"basic_{test_num}",
                "config": cfg,
                "success": False,
                "error": str(exc),
            }
        outcomes.append(record)
    return outcomes
def test_conditioned_diffusion_generation(model):
    """Test prompt-"conditioned" diffusion generation.

    NOTE(review): generation passes ``init_bits=None``, so the model is NOT
    actually conditioned on the prompt bits — the prompt is only prepended
    to the generated text for display. Confirm whether
    ``init_bits=prompt_bits`` was intended.

    Args:
        model: A BitTransformerLM instance in eval mode.

    Returns:
        A list of per-prompt result dicts (success flag, outputs or error).
    """
    print("\n🎯 === CONDITIONED FULL ATTENTION DIFFUSION GENERATION ===")
    results = []
    test_prompts = [
        "Hello",
        "Hi there",
        "What is your name?",
        "The weather is",
        "I am",
        "Yes",
        "No"
    ]
    for prompt in test_prompts:
        print(f"\n--- Prompt: '{prompt}' ---")
        try:
            # Convert prompt to bits. Currently unused by generation (see
            # NOTE above); kept so conversion errors still surface per prompt.
            prompt_bits = text_to_bits(prompt)
            # Generate continuation with diffusion (no init_bits - free-running).
            continuation_length = 45  # 5 character continuation (9 bits/char)
            generated_bits = diffusion_inference(
                model,
                length=continuation_length,
                steps=12,
                batch_size=1,
                init_bits=None,
                schedule="cosine"
            )
            # Decode only the generated continuation.
            continuation_bits = generated_bits.squeeze().tolist()
            continuation_text = bits_to_text(continuation_bits)
            # Display prompt + continuation together.
            combined_text = prompt + continuation_text
            print(f"βœ… SUCCESS: '{prompt}' β†’ '{combined_text}'")
            results.append({
                "test": "conditioned",
                "prompt": prompt,
                "success": True,
                "full_output": combined_text,
                "continuation": continuation_text,
                "bits": len(continuation_bits)
            })
        except Exception as e:
            print(f"❌ FAILED: {e}")
            results.append({
                "test": "conditioned",
                "prompt": prompt,
                "success": False,
                "error": str(e)
            })
    return results
def test_code_diffusion_completion(model):
    """Exercise code/math snippet completion through diffusion sampling.

    Generates a short completion for each snippet, decodes it to text, and
    tags the output with simple content-analysis labels.

    Args:
        model: A BitTransformerLM instance in eval mode.

    Returns:
        A list of per-snippet result dicts (success flag, outputs or error).
    """
    print("\nπŸ’» === CODE COMPLETION FULL ATTENTION DIFFUSION ===")
    outcomes = []
    snippets = [
        # Math equations
        "2 + 2 =",
        "1 + 1 =",
        "5 * 3 =",
        "10 / 2 =",
        # Programming constructs
        "def hello():",
        "if x ==",
        "for i in",
        "print(",
        "return",
        # Patterns
        "a, b, c,",
        "1, 2, 3,",
        "function(",
        "var x =",
    ]
    for snippet in snippets:
        print(f"\n--- Code: '{snippet}' ---")
        try:
            # Bit conversion (surfaces encoding errors per snippet).
            snippet_bits = text_to_bits(snippet)
            # Generate a 45-bit (5 character) completion; no init_bits.
            sample = diffusion_inference(
                model,
                length=45,
                steps=10,
                batch_size=1,
                init_bits=None,
                schedule="linear",
            )
            sampled_bits = sample.squeeze().tolist()
            completion = bits_to_text(sampled_bits)
            combined = snippet + completion
            print(f"βœ… SUCCESS: '{snippet}' β†’ '{combined}'")
            # Lightweight content analysis of the completion.
            labels = []
            if any(ch.isalnum() for ch in completion):
                labels.append("Contains alphanumeric")
                print(f" πŸ“Š Analysis: Contains alphanumeric")
            if any(ch in "0123456789" for ch in completion):
                labels.append("Contains numbers")
                print(f" πŸ”’ Analysis: Contains numbers")
            if any(ch in "=(){}[];," for ch in completion):
                labels.append("Contains code symbols")
                print(f" πŸ’» Analysis: Contains code symbols")
            outcomes.append({
                "test": "code_completion",
                "prompt": snippet,
                "success": True,
                "full_output": combined,
                "completion": completion,
                "analysis": labels,
                "bits": len(sampled_bits),
            })
        except Exception as exc:
            print(f"❌ FAILED: {exc}")
            outcomes.append({
                "test": "code_completion",
                "prompt": snippet,
                "success": False,
                "error": str(exc),
            })
    return outcomes
def compare_with_previous_results():
    """Print reference numbers from the previous (chunked attention) run."""
    summary_lines = (
        "\nβš–οΈ === COMPARISON WITH PREVIOUS RESULTS ===",
        "Previous chunked attention model achieved:",
        "- Basic generation: 3/3 success (100%)",
        "- Conditioned generation: 7/7 success (100%)",
        "- Code completion: 13/13 success (100%)",
        "- All diffusion inference succeeded vs 0% autoregressive",
        "\nTesting if full attention training improved quality...",
    )
    for line in summary_lines:
        print(line)
def main():
    """Entry point: load the model, run all test suites, print a summary.

    Returns:
        A dict bundling the three per-suite result lists plus summary
        statistics (counts, success rate, ISO timestamp).
    """
    divider = "=" * 70
    print("πŸš€ FULL ATTENTION BITRANSFORMERLM DIFFUSION INFERENCE TEST")
    print(divider)
    print("Testing newly trained full bi-directional attention model")
    print("with denoising diffusion generation")
    print(divider)

    model = load_full_attention_model()

    # Execute each suite in turn.
    basic_results = test_basic_diffusion_generation(model)
    conditioned_results = test_conditioned_diffusion_generation(model)
    code_results = test_code_diffusion_completion(model)

    compare_with_previous_results()

    def _passed(records):
        # Count records flagged as successful.
        return sum(1 for rec in records if rec.get('success', False))

    all_records = basic_results + conditioned_results + code_results
    total_tests = len(all_records)
    successful_tests = _passed(all_records)
    success_rate = (successful_tests / total_tests) * 100 if total_tests > 0 else 0

    print(f"\n🎯 === FINAL SUMMARY ===")
    print(f"Total tests: {total_tests}")
    print(f"Successful: {successful_tests}")
    print(f"Success rate: {success_rate:.1f}%")
    print(f"\nBreakdown:")
    print(f"- Basic generation: {_passed(basic_results)}/{len(basic_results)}")
    print(f"- Conditioned generation: {_passed(conditioned_results)}/{len(conditioned_results)}")
    print(f"- Code completion: {_passed(code_results)}/{len(code_results)}")

    # Returned for documentation / interactive inspection.
    return {
        'basic_results': basic_results,
        'conditioned_results': conditioned_results,
        'code_results': code_results,
        'summary': {
            'total_tests': total_tests,
            'successful_tests': successful_tests,
            'success_rate': success_rate,
            'timestamp': datetime.now().isoformat(),
        },
    }
# Run the full test suite when executed as a script. The results dict is
# bound (not printed) so it stays available for interactive inspection,
# e.g. when run via `python -i`.
if __name__ == "__main__":
    results = main()