"""
Full Attention BitTransformerLM Diffusion Inference Test
========================================================

Test the newly trained full bi-directional attention BitTransformerLM model
using denoising diffusion generation to evaluate improvements from full
attention training.

Model Configuration:
- Same full bi-directional unchunked attention as training (chunk_size=None)
- Proper eval() mode with dropout management
- Use latest checkpoint_best.pt from full attention training
- Test with same diffusion inference that worked before
"""
| |
|
import sys
from datetime import datetime

import torch
import torch.nn.functional as F

# Make the project package importable when running this script directly.
sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import (
    BitTransformerLM,
    text_to_bits,
    bits_to_text,
    diffusion_inference,
    set_dropout,
)
| |
|
def load_full_attention_model():
    """Load the newly trained full attention BitTransformerLM model.

    Builds the model with the same full bi-directional, unchunked attention
    configuration used during training (``chunk_size=None``), restores the
    weights from the best checkpoint, and switches the model into a fully
    deterministic inference mode.

    Returns:
        BitTransformerLM: the restored model in eval mode with dropout
        disabled.

    Raises:
        FileNotFoundError: if the checkpoint file is missing.
        KeyError: if the checkpoint lacks ``model_state_dict``.
    """
    print("Loading Full Attention BitTransformerLM for diffusion inference...")

    # Architecture must match the training run exactly so the checkpoint's
    # state_dict keys and tensor shapes line up.
    model = BitTransformerLM(
        d_model=512,
        nhead=16,
        num_layers=8,
        dim_feedforward=1024,
        max_seq_len=512,
        reversible=True,
        use_checkpoint=False,
        use_autocast=False,
        use_act=True,
        act_threshold=0.9,
        lambda_K=0.05,
        lambda_C=0.05,
        lambda_S=0.05,
        chunk_size=None,  # full (unchunked) bi-directional attention
        overlap=0,
        full_attn_logging=True,
    )

    checkpoint_path = '/data/BitTransformerLM/checkpoints/checkpoint_best.pt'
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources (this one is produced locally).
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])

    # Deterministic inference: eval mode plus an explicit dropout disable,
    # per the module docstring's "dropout management" requirement.
    model.eval()
    set_dropout(model, 0.0)

    epoch = checkpoint.get('epoch', 'unknown')
    loss = checkpoint.get('loss', 'unknown')
    print(f"[OK] Full Attention Model loaded! Epoch: {epoch}, Loss: {loss}")

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Parameters: {total_params:,}")

    return model
| |
|
def test_basic_diffusion_generation(model):
    """Test basic unconditional diffusion generation.

    Runs three generation configs (varying length, step count, and noise
    schedule) and decodes each bit sequence back to text.

    Args:
        model: a BitTransformerLM in eval mode.

    Returns:
        list[dict]: one result record per config with a ``success`` flag and
        either the decoded output or the error message.
    """
    print("\n=== BASIC FULL ATTENTION DIFFUSION GENERATION ===")

    results = []

    # Lengths are multiples of 9 — presumably 9 bits encode one character
    # (see the length//9 char count below) — TODO confirm encoding.
    test_configs = [
        {"length": 36, "steps": 8, "schedule": "linear"},
        {"length": 45, "steps": 12, "schedule": "cosine"},
        {"length": 54, "steps": 16, "schedule": "exp"},
    ]

    for i, config in enumerate(test_configs, 1):
        print(f"\n--- Test {i}: {config['length']//9} chars, {config['schedule']} ---")

        try:
            generated_bits = diffusion_inference(
                model,
                length=config['length'],
                steps=config['steps'],
                batch_size=1,
                schedule=config['schedule'],
            )

            bit_list = generated_bits.squeeze().tolist()
            decoded_text = bits_to_text(bit_list)

            print(f"[OK] SUCCESS: '{decoded_text}'")
            results.append({
                "test": f"basic_{i}",
                "config": config,
                "success": True,
                "output": decoded_text,
                "bits": len(bit_list),
            })

        except Exception as e:
            # Best-effort test harness: record the failure and keep going.
            print(f"[FAIL] FAILED: {e}")
            results.append({
                "test": f"basic_{i}",
                "config": config,
                "success": False,
                "error": str(e),
            })

    return results
| |
|
def test_conditioned_diffusion_generation(model):
    """Test prompt-"conditioned" diffusion generation.

    NOTE(review): ``init_bits`` is passed as ``None``, so the prompt is never
    actually fed to the model — generation here is unconditional and the
    prompt text is only prepended to the decoded continuation. Confirm
    ``diffusion_inference``'s ``init_bits`` contract before treating these
    results as true conditioning.

    Args:
        model: a BitTransformerLM in eval mode.

    Returns:
        list[dict]: one result record per prompt with a ``success`` flag.
    """
    print("\n=== CONDITIONED FULL ATTENTION DIFFUSION GENERATION ===")

    results = []

    test_prompts = [
        "Hello",
        "Hi there",
        "What is your name?",
        "The weather is",
        "I am",
        "Yes",
        "No",
    ]

    for prompt in test_prompts:
        print(f"\n--- Prompt: '{prompt}' ---")

        try:
            # 45 bits ~ 5 chars at 9 bits/char — TODO confirm encoding.
            continuation_length = 45
            generated_bits = diffusion_inference(
                model,
                length=continuation_length,
                steps=12,
                batch_size=1,
                init_bits=None,  # see NOTE(review) above: prompt unused
                schedule="cosine",
            )

            continuation_bits = generated_bits.squeeze().tolist()
            continuation_text = bits_to_text(continuation_bits)

            combined_text = prompt + continuation_text
            print(f"[OK] SUCCESS: '{prompt}' -> '{combined_text}'")
            results.append({
                "test": "conditioned",
                "prompt": prompt,
                "success": True,
                "full_output": combined_text,
                "continuation": continuation_text,
                "bits": len(continuation_bits),
            })

        except Exception as e:
            # Best-effort test harness: record the failure and keep going.
            print(f"[FAIL] FAILED: {e}")
            results.append({
                "test": "conditioned",
                "prompt": prompt,
                "success": False,
                "error": str(e),
            })

    return results
| |
|
def test_code_diffusion_completion(model):
    """Test code/math completion with diffusion.

    NOTE(review): as with the conditioned test, ``init_bits=None`` means the
    code prompt is not fed to the model; the generated completion is
    unconditional and merely appended to the prompt for display.

    Args:
        model: a BitTransformerLM in eval mode.

    Returns:
        list[dict]: one result record per prompt, including a lightweight
        character-class ``analysis`` of each successful completion.
    """
    print("\n=== CODE COMPLETION FULL ATTENTION DIFFUSION ===")

    results = []

    test_cases = [
        # Arithmetic prompts
        "2 + 2 =",
        "1 + 1 =",
        "5 * 3 =",
        "10 / 2 =",
        # Code prompts
        "def hello():",
        "if x ==",
        "for i in",
        "print(",
        "return",
        # Pattern prompts
        "a, b, c,",
        "1, 2, 3,",
        "function(",
        "var x =",
    ]

    for code in test_cases:
        print(f"\n--- Code: '{code}' ---")

        try:
            # 45 bits ~ 5 chars at 9 bits/char — TODO confirm encoding.
            completion_length = 45
            generated_bits = diffusion_inference(
                model,
                length=completion_length,
                steps=10,
                batch_size=1,
                init_bits=None,  # see NOTE(review) above: prompt unused
                schedule="linear",
            )

            completion_bits = generated_bits.squeeze().tolist()
            completion = bits_to_text(completion_bits)

            combined_text = code + completion
            print(f"[OK] SUCCESS: '{code}' -> '{combined_text}'")

            # Cheap heuristic signal of whether the completion looks code-like.
            analysis = []
            if any(c.isalnum() for c in completion):
                analysis.append("Contains alphanumeric")
                print("  Analysis: Contains alphanumeric")
            if any(c in "0123456789" for c in completion):
                analysis.append("Contains numbers")
                print("  Analysis: Contains numbers")
            if any(c in "=(){}[];," for c in completion):
                analysis.append("Contains code symbols")
                print("  Analysis: Contains code symbols")

            results.append({
                "test": "code_completion",
                "prompt": code,
                "success": True,
                "full_output": combined_text,
                "completion": completion,
                "analysis": analysis,
                "bits": len(completion_bits),
            })

        except Exception as e:
            # Best-effort test harness: record the failure and keep going.
            print(f"[FAIL] FAILED: {e}")
            results.append({
                "test": "code_completion",
                "prompt": code,
                "success": False,
                "error": str(e),
            })

    return results
| |
|
def compare_with_previous_results():
    """Print a reference summary of the previous chunked-attention results.

    Purely informational: these baseline numbers are hard-coded from the
    earlier chunked-attention run for side-by-side comparison.
    """
    print("\n=== COMPARISON WITH PREVIOUS RESULTS ===")
    print("Previous chunked attention model achieved:")
    print("- Basic generation: 3/3 success (100%)")
    print("- Conditioned generation: 7/7 success (100%)")
    print("- Code completion: 13/13 success (100%)")
    print("- All diffusion inference succeeded vs 0% autoregressive")
    print("\nTesting if full attention training improved quality...")
| |
|
def main():
    """Run the full diffusion-inference test suite and return all results.

    Loads the model, runs the three test groups, prints a summary, and
    returns a dict with every per-test record plus aggregate stats.

    Returns:
        dict: ``basic_results``, ``conditioned_results``, ``code_results``,
        and a ``summary`` with counts, success rate, and a timestamp.
    """
    print("FULL ATTENTION BITRANSFORMERLM DIFFUSION INFERENCE TEST")
    print("=" * 70)
    print("Testing newly trained full bi-directional attention model")
    print("with denoising diffusion generation")
    print("=" * 70)

    model = load_full_attention_model()

    basic_results = test_basic_diffusion_generation(model)
    conditioned_results = test_conditioned_diffusion_generation(model)
    code_results = test_code_diffusion_completion(model)

    compare_with_previous_results()

    def _passed(results):
        # Count records whose success flag is truthy.
        return sum(1 for r in results if r.get('success', False))

    all_results = basic_results + conditioned_results + code_results
    total_tests = len(all_results)
    successful_tests = _passed(all_results)
    success_rate = (successful_tests / total_tests) * 100 if total_tests > 0 else 0

    print("\n=== FINAL SUMMARY ===")
    print(f"Total tests: {total_tests}")
    print(f"Successful: {successful_tests}")
    print(f"Success rate: {success_rate:.1f}%")

    print("\nBreakdown:")
    print(f"- Basic generation: {_passed(basic_results)}/{len(basic_results)}")
    print(f"- Conditioned generation: {_passed(conditioned_results)}/{len(conditioned_results)}")
    print(f"- Code completion: {_passed(code_results)}/{len(code_results)}")

    return {
        'basic_results': basic_results,
        'conditioned_results': conditioned_results,
        'code_results': code_results,
        'summary': {
            'total_tests': total_tests,
            'successful_tests': successful_tests,
            'success_rate': success_rate,
            'timestamp': datetime.now().isoformat(),
        },
    }
| |
|
# Script entry point: run the suite and keep the results for interactive use.
if __name__ == "__main__":
    results = main()