| |
| """ |
| Full Attention BitTransformerLM Diffusion Inference Test |
| ======================================================== |
| |
| Test the newly trained full bi-directional attention BitTransformerLM model |
| using denoising diffusion generation to evaluate improvements from full attention training. |
| |
| Model Configuration: |
| - Same full bi-directional unchunked attention as training (chunk_size=None) |
| - Proper eval() mode with dropout management |
| - Use latest checkpoint_best.pt from full attention training |
| - Test with same diffusion inference that worked before |
| """ |
|
|
| import sys |
| import torch |
| import torch.nn.functional as F |
| from datetime import datetime |
|
|
| sys.path.append('/data') |
| sys.path.append('/data/BitTransformerLM') |
|
|
| from bit_transformer import ( |
| BitTransformerLM, |
| text_to_bits, |
| bits_to_text, |
| diffusion_inference, |
| set_dropout |
| ) |
|
|
def load_full_attention_model():
    """Load the newly trained full-attention BitTransformerLM model.

    Builds the model with the same full bi-directional, unchunked attention
    configuration used during training (``chunk_size=None``), restores the
    weights from the best checkpoint, and switches to deterministic eval
    mode with dropout disabled.

    Returns:
        BitTransformerLM: restored model, in eval mode, ready for inference.

    Raises:
        FileNotFoundError: if the checkpoint file is missing.
        KeyError: if the checkpoint lacks ``model_state_dict``.
    """
    print("π Loading Full Attention BitTransformerLM for diffusion inference...")

    # Hyperparameters must mirror the training run exactly, otherwise
    # load_state_dict() below will fail on shape mismatches.
    model = BitTransformerLM(
        d_model=512,
        nhead=16,
        num_layers=8,
        dim_feedforward=1024,
        max_seq_len=512,
        reversible=True,
        use_checkpoint=False,
        use_autocast=False,
        use_act=True,
        act_threshold=0.9,
        lambda_K=0.05,
        lambda_C=0.05,
        lambda_S=0.05,
        chunk_size=None,  # None => full unchunked bi-directional attention
        overlap=0,
        full_attn_logging=True,
    )

    checkpoint_path = '/data/BitTransformerLM/checkpoints/checkpoint_best.pt'
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])

    # Deterministic inference: eval() plus an explicit zero dropout rate.
    model.eval()
    set_dropout(model, 0.0)

    epoch = checkpoint.get('epoch', 'unknown')
    loss = checkpoint.get('loss', 'unknown')

    # BUGFIX: this message was split across two physical lines inside the
    # string literal, which is a SyntaxError in Python; rejoined onto one line.
    print(f"β Full Attention Model loaded! Epoch: {epoch}, Loss: {loss}")

    total_params = sum(p.numel() for p in model.parameters())
    print(f"π Parameters: {total_params:,}")

    return model
|
|
def test_basic_diffusion_generation(model):
    """Run unconditional diffusion generation at several lengths/schedules.

    Args:
        model: BitTransformerLM instance in eval mode.

    Returns:
        list[dict]: one result record per configuration, each with a
        ``success`` flag plus either the decoded output or the error string.
    """
    print("\nπ§ͺ === BASIC FULL ATTENTION DIFFUSION GENERATION ===")

    results = []

    # Lengths are multiples of 9 bits (one text character per 9 bits, per the
    # "length//9 chars" print below — TODO confirm against text_to_bits).
    test_configs = [
        {"length": 36, "steps": 8, "schedule": "linear"},
        {"length": 45, "steps": 12, "schedule": "cosine"},
        {"length": 54, "steps": 16, "schedule": "exp"},
    ]

    for i, config in enumerate(test_configs, 1):
        print(f"\n--- Test {i}: {config['length']//9} chars, {config['schedule']} ---")

        try:
            generated_bits = diffusion_inference(
                model,
                length=config['length'],
                steps=config['steps'],
                batch_size=1,
                schedule=config['schedule'],
            )

            bit_list = generated_bits.squeeze().tolist()
            decoded_text = bits_to_text(bit_list)

            # BUGFIX: this string literal was split across two physical
            # lines (SyntaxError); rejoined onto one line.
            print(f"β SUCCESS: '{decoded_text}'")
            results.append({
                "test": f"basic_{i}",
                "config": config,
                "success": True,
                "output": decoded_text,
                "bits": len(bit_list),
            })

        # Broad catch is intentional: this harness records any failure mode
        # per-config and keeps running the remaining tests.
        except Exception as e:
            print(f"β FAILED: {e}")
            results.append({
                "test": f"basic_{i}",
                "config": config,
                "success": False,
                "error": str(e),
            })

    return results
|
|
def test_conditioned_diffusion_generation(model):
    """Run "prompt-conditioned" diffusion generation over a set of prompts.

    NOTE(review): the prompt is encoded via text_to_bits() but never passed
    to diffusion_inference() (init_bits=None), so the generation is NOT
    actually conditioned on the prompt — the prompt text is only prepended
    to the decoded continuation. Confirm whether init_bits=prompt tensor
    was intended.

    Args:
        model: BitTransformerLM instance in eval mode.

    Returns:
        list[dict]: one result record per prompt with a ``success`` flag.
    """
    print("\nπ― === CONDITIONED FULL ATTENTION DIFFUSION GENERATION ===")

    results = []

    test_prompts = [
        "Hello",
        "Hi there",
        "What is your name?",
        "The weather is",
        "I am",
        "Yes",
        "No",
    ]

    for prompt in test_prompts:
        print(f"\n--- Prompt: '{prompt}' ---")

        try:
            prompt_bits = text_to_bits(prompt)  # encoded but unused — see NOTE above

            continuation_length = 45
            generated_bits = diffusion_inference(
                model,
                length=continuation_length,
                steps=12,
                batch_size=1,
                init_bits=None,
                schedule="cosine",
            )

            continuation_bits = generated_bits.squeeze().tolist()
            continuation_text = bits_to_text(continuation_bits)

            combined_text = prompt + continuation_text
            # BUGFIX: this string literal was split across two physical
            # lines (SyntaxError); rejoined onto one line.
            print(f"β SUCCESS: '{prompt}' β '{combined_text}'")
            results.append({
                "test": "conditioned",
                "prompt": prompt,
                "success": True,
                "full_output": combined_text,
                "continuation": continuation_text,
                "bits": len(continuation_bits),
            })

        # Broad catch is intentional: record failure and continue with the
        # remaining prompts.
        except Exception as e:
            print(f"β FAILED: {e}")
            results.append({
                "test": "conditioned",
                "prompt": prompt,
                "success": False,
                "error": str(e),
            })

    return results
|
|
def test_code_diffusion_completion(model):
    """Run diffusion generation framed as code/math completion.

    NOTE(review): as in the conditioned test, the prompt is encoded via
    text_to_bits() but never passed to diffusion_inference()
    (init_bits=None), so the "completion" is unconditional generation with
    the prompt text prepended afterwards — confirm intent.

    Args:
        model: BitTransformerLM instance in eval mode.

    Returns:
        list[dict]: one result record per prompt, including a lightweight
        character-class ``analysis`` of the generated continuation.
    """
    print("\nπ» === CODE COMPLETION FULL ATTENTION DIFFUSION ===")

    results = []

    test_cases = [
        # Arithmetic prompts
        "2 + 2 =",
        "1 + 1 =",
        "5 * 3 =",
        "10 / 2 =",
        # Code-structure prompts
        "def hello():",
        "if x ==",
        "for i in",
        "print(",
        "return",
        # Sequence/punctuation prompts
        "a, b, c,",
        "1, 2, 3,",
        "function(",
        "var x =",
    ]

    for code in test_cases:
        print(f"\n--- Code: '{code}' ---")

        try:
            code_bits = text_to_bits(code)  # encoded but unused — see NOTE above

            completion_length = 45
            generated_bits = diffusion_inference(
                model,
                length=completion_length,
                steps=10,
                batch_size=1,
                init_bits=None,
                schedule="linear",
            )

            completion_bits = generated_bits.squeeze().tolist()
            completion = bits_to_text(completion_bits)

            combined_text = code + completion
            # BUGFIX: this string literal was split across two physical
            # lines (SyntaxError); rejoined onto one line.
            print(f"β SUCCESS: '{code}' β '{combined_text}'")

            # Cheap heuristics on the continuation's character classes.
            analysis = []
            if any(c.isalnum() for c in completion):
                analysis.append("Contains alphanumeric")
                print(f" π Analysis: Contains alphanumeric")
            if any(c in "0123456789" for c in completion):
                analysis.append("Contains numbers")
                print(f" π’ Analysis: Contains numbers")
            if any(c in "=(){}[];," for c in completion):
                analysis.append("Contains code symbols")
                print(f" π» Analysis: Contains code symbols")

            results.append({
                "test": "code_completion",
                "prompt": code,
                "success": True,
                "full_output": combined_text,
                "completion": completion,
                "analysis": analysis,
                "bits": len(completion_bits),
            })

        # Broad catch is intentional: record failure and continue.
        except Exception as e:
            print(f"β FAILED: {e}")
            results.append({
                "test": "code_completion",
                "prompt": code,
                "success": False,
                "error": str(e),
            })

    return results
|
|
def compare_with_previous_results():
    """Print the reference success rates from the earlier chunked-attention run."""
    summary_lines = (
        "\nβοΈ === COMPARISON WITH PREVIOUS RESULTS ===",
        "Previous chunked attention model achieved:",
        "- Basic generation: 3/3 success (100%)",
        "- Conditioned generation: 7/7 success (100%)",
        "- Code completion: 13/13 success (100%)",
        "- All diffusion inference succeeded vs 0% autoregressive",
        "\nTesting if full attention training improved quality...",
    )
    for line in summary_lines:
        print(line)
|
|
def main():
    """Run the full-attention diffusion test suite and summarize results.

    Loads the checkpointed model, runs the three test groups (basic,
    conditioned, code completion), prints a success breakdown, and returns
    all results plus a timestamped summary dict.

    Returns:
        dict: ``basic_results``, ``conditioned_results``, ``code_results``
        lists plus a ``summary`` dict with totals and an ISO timestamp.
    """
    print("π FULL ATTENTION BITRANSFORMERLM DIFFUSION INFERENCE TEST")
    print("=" * 70)
    print("Testing newly trained full bi-directional attention model")
    print("with denoising diffusion generation")
    print("=" * 70)

    model = load_full_attention_model()

    basic_results = test_basic_diffusion_generation(model)
    conditioned_results = test_conditioned_diffusion_generation(model)
    code_results = test_code_diffusion_completion(model)

    compare_with_previous_results()

    all_results = basic_results + conditioned_results + code_results
    total_tests = len(all_results)
    successful_tests = sum(1 for r in all_results if r.get('success', False))
    # Guard against division by zero if every test group came back empty.
    success_rate = (successful_tests / total_tests) * 100 if total_tests > 0 else 0

    # IDIOM: dropped useless `f` prefixes from placeholder-free strings.
    print("\nπ― === FINAL SUMMARY ===")
    print(f"Total tests: {total_tests}")
    print(f"Successful: {successful_tests}")
    print(f"Success rate: {success_rate:.1f}%")

    print("\nBreakdown:")
    print(f"- Basic generation: {sum(1 for r in basic_results if r.get('success', False))}/{len(basic_results)}")
    print(f"- Conditioned generation: {sum(1 for r in conditioned_results if r.get('success', False))}/{len(conditioned_results)}")
    print(f"- Code completion: {sum(1 for r in code_results if r.get('success', False))}/{len(code_results)}")

    return {
        'basic_results': basic_results,
        'conditioned_results': conditioned_results,
        'code_results': code_results,
        'summary': {
            'total_tests': total_tests,
            'successful_tests': successful_tests,
            'success_rate': success_rate,
            'timestamp': datetime.now().isoformat(),
        },
    }
|
|
# Script entry point: run the full test suite and keep the results dict
# around (useful when the script is run interactively / under a debugger).
if __name__ == "__main__":
    results = main()