"""
BitTransformerLM Denoising Diffusion Inference Tests
====================================================

Test the breakthrough model using built-in denoising diffusion generation
to potentially resolve parity errors and improve text quality.
"""

import sys
import torch
import math
import logging

sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text, diffusion_inference

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
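

# The manual-decode fallbacks below all walk the bit stream in strides of 9,
# which suggests text_to_bits frames each character as 8 data bits plus 1
# parity bit. That framing is an inference from the loops, not a documented
# part of the bit_transformer API; this helper is a reusable sketch of the
# same fallback logic under that assumption.
def manual_parity_decode(bits):
    """Best-effort decode of 9-bit framed bits (8 data + 1 parity) to text."""
    text = ""
    for i in range(0, len(bits), 9):
        if i + 8 < len(bits):
            char_bits = bits[i:i + 8]  # data bits; bits[i + 8] would be parity
            byte_val = sum(bit * (2 ** (7 - j)) for j, bit in enumerate(char_bits))
            text += chr(byte_val) if 32 <= byte_val <= 126 else '?'
    return text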


def load_breakthrough_model():
    """Load the trained breakthrough BitTransformerLM."""
    print("Loading breakthrough BitTransformerLM for diffusion inference...")

    model = BitTransformerLM(
        d_model=512,
        nhead=16,
        num_layers=8,
        dim_feedforward=1024,
        max_seq_len=512,
        reversible=True,
        use_checkpoint=False,
        use_autocast=False,
        use_act=True,
        act_threshold=0.9,
        lambda_K=0.05,
        lambda_C=0.05,
        lambda_S=0.05,
    )

    checkpoint = torch.load('/data/BitTransformerLM/checkpoints/checkpoint_best.pt', map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print(f"Model loaded! Loss: {checkpoint['loss']:.6f}, Epoch: {checkpoint['epoch']}")

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Parameters: {total_params:,}")

    return model
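

# Minimal forward-pass smoke test. Assumes the model returns a
# (logits, telemetry) tuple, matching how it is called in
# compare_diffusion_vs_autoregressive below. Illustrative only; main() does
# not call it.
def sanity_check_forward(model, seq_len=64):
    """Push a random bit sequence through the model and report logits shape."""
    with torch.no_grad():
        bits = torch.randint(0, 2, (1, seq_len), dtype=torch.long)
        logits, _ = model(bits)
    print(f"Sanity check: input {tuple(bits.shape)} -> logits {tuple(logits.shape)}")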


def test_basic_diffusion_generation(model):
    """Test basic diffusion generation without conditioning."""
    print("\n=== BASIC DIFFUSION GENERATION TESTS ===")

    test_configs = [
        {"length": 36, "steps": 8, "schedule": "linear", "name": "4 chars, linear"},
        {"length": 45, "steps": 12, "schedule": "cosine", "name": "5 chars, cosine"},
        {"length": 54, "steps": 16, "schedule": "exp", "name": "6 chars, exp"},
    ]

    results = []

    for config in test_configs:
        print(f"\n--- {config['name']} ---")
        print(f"Config: {config['length']} bits, {config['steps']} steps, {config['schedule']} schedule")

        try:
            generated_bits = diffusion_inference(
                model,
                length=config['length'],
                steps=config['steps'],
                schedule=config['schedule'],
            )

            bits_list = generated_bits.squeeze().tolist()
            print(f"Generated {len(bits_list)} bits: {bits_list[:18]}...")  # first 18 bits = ~2 chars

            try:
                text = bits_to_text(bits_list)
                print(f"SUCCESS: '{text}'")
                results.append({"config": config, "text": text, "success": True})
            except Exception as decode_error:
                print(f"Decode failed: {decode_error}")

                # Fallback: drop the assumed parity bit (9 bits per char:
                # 8 data bits + 1 parity) and keep printable ASCII only.
                manual_text = ""
                for i in range(0, len(bits_list), 9):
                    if i + 8 < len(bits_list):
                        char_bits = bits_list[i:i + 8]
                        byte_val = sum(bit * (2 ** (7 - j)) for j, bit in enumerate(char_bits))
                        if 32 <= byte_val <= 126:
                            manual_text += chr(byte_val)
                        else:
                            manual_text += '?'

                print(f"Manual decode: '{manual_text}'")
                results.append({"config": config, "text": manual_text, "success": False})

        except Exception as e:
            print(f"Generation failed: {e}")
            results.append({"config": config, "text": None, "success": False, "error": str(e)})

    return results
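

# For reference, plausible shapes for the three schedule names exercised above,
# assuming each maps a normalized step t in [0, 1] to a noise level in [0, 1].
# The actual schedules inside diffusion_inference are not shown in this file;
# this sketch only illustrates the conventional definitions the names imply.
def sketch_noise_level(t, schedule="linear"):
    """Hypothetical noise level at normalized step t (1.0 = fully noised)."""
    if schedule == "linear":
        return 1.0 - t
    if schedule == "cosine":
        return math.cos(t * math.pi / 2)
    if schedule == "exp":
        return math.exp(-5.0 * t)
    raise ValueError(f"Unknown schedule: {schedule}")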


def test_conditioned_diffusion_generation(model):
    """Test diffusion generation conditioned on prompts."""
    print("\n=== CONDITIONED DIFFUSION GENERATION TESTS ===")

    prompts = [
        "Hello",
        "Hi there",
        "What is your name?",
        "The weather is",
        "I am",
        "Yes",
        "No",
    ]

    results = []

    for prompt in prompts:
        print(f"\n--- Prompt: '{prompt}' ---")

        prompt_bits = text_to_bits(prompt)
        print(f"Prompt: {len(prompt_bits)} bits")

        # Reserve 45 bits (~5 chars at 9 bits/char) for the continuation.
        total_length = len(prompt_bits) + 45

        # Seed the sequence: prompt bits up front, random bits after.
        init_bits = torch.zeros(1, total_length, dtype=torch.long)
        init_bits[0, :len(prompt_bits)] = torch.tensor(prompt_bits, dtype=torch.long)
        init_bits[0, len(prompt_bits):] = torch.randint(0, 2, (total_length - len(prompt_bits),))

        try:
            generated_bits = diffusion_inference(
                model,
                length=total_length,
                steps=12,
                init_bits=init_bits,
                schedule="cosine",
            )

            full_bits = generated_bits.squeeze().tolist()
            generated_only = full_bits[len(prompt_bits):]

            print(f"Generated {len(generated_only)} bits for continuation")

            try:
                continuation = bits_to_text(generated_only)
                full_result = prompt + continuation
                print(f"SUCCESS: '{prompt}' -> '{full_result}'")
                results.append({
                    "prompt": prompt,
                    "continuation": continuation,
                    "full_result": full_result,
                    "success": True,
                })
            except Exception as decode_error:
                print(f"Decode failed: {decode_error}")

                # Same parity-dropping fallback as in the basic tests.
                manual_continuation = ""
                for i in range(0, len(generated_only), 9):
                    if i + 8 < len(generated_only):
                        char_bits = generated_only[i:i + 8]
                        byte_val = sum(bit * (2 ** (7 - j)) for j, bit in enumerate(char_bits))
                        if 32 <= byte_val <= 126:
                            manual_continuation += chr(byte_val)
                        else:
                            manual_continuation += '?'

                full_result = prompt + manual_continuation
                print(f"Manual decode: '{prompt}' -> '{full_result}'")
                results.append({
                    "prompt": prompt,
                    "continuation": manual_continuation,
                    "full_result": full_result,
                    "success": False,
                })

        except Exception as e:
            print(f"Generation failed: {e}")
            results.append({
                "prompt": prompt,
                "continuation": None,
                "full_result": None,
                "success": False,
                "error": str(e),
            })

    return results
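

# Note on conditioning: the prompt enters only through init_bits. Whether the
# prompt region stays fixed during denoising depends on how diffusion_inference
# treats init_bits internally (not shown in this file); if it re-noises the
# whole sequence, the prompt itself may drift during generation.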


def test_code_diffusion_completion(model):
    """Test diffusion generation on code/math completion."""
    print("\n=== CODE DIFFUSION COMPLETION TESTS ===")

    code_prompts = [
        # Arithmetic
        "2 + 2 =",
        "1 + 1 =",
        "5 * 3 =",
        "10 / 2 =",

        # Code constructs
        "def hello():",
        "if x ==",
        "for i in",
        "print(",
        "return",

        # Sequences and patterns
        "a, b, c,",
        "1, 2, 3,",
        "function(",
        "var x =",
    ]

    results = []

    for prompt in code_prompts:
        print(f"\n--- Code: '{prompt}' ---")

        prompt_bits = text_to_bits(prompt)
        print(f"Prompt: {len(prompt_bits)} bits")

        # Reserve 36 bits (~4 chars at 9 bits/char) for the completion.
        completion_length = 36
        total_length = len(prompt_bits) + completion_length

        init_bits = torch.zeros(1, total_length, dtype=torch.long)
        init_bits[0, :len(prompt_bits)] = torch.tensor(prompt_bits, dtype=torch.long)
        init_bits[0, len(prompt_bits):] = torch.randint(0, 2, (completion_length,))

        try:
            generated_bits = diffusion_inference(
                model,
                length=total_length,
                steps=16,
                init_bits=init_bits,
                schedule="exp",
            )

            full_bits = generated_bits.squeeze().tolist()
            completion_bits = full_bits[len(prompt_bits):]

            try:
                completion = bits_to_text(completion_bits)
                full_result = prompt + completion
                print(f"SUCCESS: '{prompt}' -> '{full_result}'")

                # Rough content analysis of the completion.
                analysis = []
                if any(c.isalnum() for c in completion):
                    analysis.append("Contains alphanumeric")
                if any(c in "0123456789" for c in completion):
                    analysis.append("Contains numbers")
                if any(c in "=(){}[];," for c in completion):
                    analysis.append("Contains code symbols")
                if any(c in " \n\t" for c in completion):
                    analysis.append("Contains whitespace")

                if analysis:
                    print(f"  Analysis: {', '.join(analysis)}")

                results.append({
                    "prompt": prompt,
                    "completion": completion,
                    "full_result": full_result,
                    "analysis": analysis,
                    "success": True,
                })

            except Exception as decode_error:
                print(f"Decode failed: {decode_error}")

                # Parity-dropping fallback, plus a tally of character classes.
                manual_completion = ""
                char_types = {"letters": 0, "numbers": 0, "symbols": 0, "printable": 0}

                for i in range(0, len(completion_bits), 9):
                    if i + 8 < len(completion_bits):
                        char_bits = completion_bits[i:i + 8]
                        byte_val = sum(bit * (2 ** (7 - j)) for j, bit in enumerate(char_bits))
                        if 32 <= byte_val <= 126:
                            char = chr(byte_val)
                            manual_completion += char
                            char_types["printable"] += 1
                            if char.isalpha():
                                char_types["letters"] += 1
                            elif char.isdigit():
                                char_types["numbers"] += 1
                            elif char in "=(){}[];,+-*/<>!@#$%^&":
                                char_types["symbols"] += 1
                        else:
                            manual_completion += '?'

                full_result = prompt + manual_completion
                print(f"Manual decode: '{prompt}' -> '{full_result}'")
                print(f"  Character types: {char_types}")

                results.append({
                    "prompt": prompt,
                    "completion": manual_completion,
                    "full_result": full_result,
                    "char_types": char_types,
                    "success": False,
                })

        except Exception as e:
            print(f"Generation failed: {e}")
            results.append({
                "prompt": prompt,
                "completion": None,
                "full_result": None,
                "success": False,
                "error": str(e),
            })

    return results


def compare_diffusion_vs_autoregressive(model):
    """Compare diffusion vs. autoregressive generation quality."""
    print("\n=== DIFFUSION vs AUTOREGRESSIVE COMPARISON ===")

    test_prompts = ["Hello", "Hi", "The cat", "Yes"]
    comparison_results = []

    for prompt in test_prompts:
        print(f"\n--- Comparing generation for: '{prompt}' ---")

        prompt_bits = text_to_bits(prompt)
        generation_length = 27  # ~3 chars at 9 bits/char

        # Autoregressive baseline: sample one bit at a time.
        print("Autoregressive generation:")
        try:
            generated_bits_ar = prompt_bits.copy()

            with torch.no_grad():
                for i in range(generation_length):
                    # Keep the context well under max_seq_len=512.
                    context = generated_bits_ar[-300:] if len(generated_bits_ar) > 300 else generated_bits_ar
                    context_tensor = torch.tensor(context, dtype=torch.long).unsqueeze(0)

                    logits, _ = model(context_tensor)
                    next_bit_logits = logits[0, -1, :]

                    # Temperature sampling (T = 0.8).
                    next_bit_logits = next_bit_logits / 0.8
                    probs = torch.softmax(next_bit_logits, dim=-1)
                    next_bit = torch.multinomial(probs, 1).item()

                    generated_bits_ar.append(next_bit)

            ar_completion_bits = generated_bits_ar[len(prompt_bits):]
            try:
                ar_completion = bits_to_text(ar_completion_bits)
                ar_success = True
            except Exception:
                ar_completion = "DECODE_FAILED"
                ar_success = False

            print(f"  Result: '{prompt}' -> '{prompt + ar_completion}' (Success: {ar_success})")

        except Exception as e:
            ar_completion = f"ERROR: {e}"
            ar_success = False
            print(f"  Result: ERROR - {e}")

        # Diffusion: denoise the whole sequence at once.
        print("Diffusion generation:")
        try:
            total_length = len(prompt_bits) + generation_length
            init_bits = torch.zeros(1, total_length, dtype=torch.long)
            init_bits[0, :len(prompt_bits)] = torch.tensor(prompt_bits, dtype=torch.long)
            init_bits[0, len(prompt_bits):] = torch.randint(0, 2, (generation_length,))

            generated_bits_diff = diffusion_inference(
                model,
                length=total_length,
                steps=12,
                init_bits=init_bits,
                schedule="cosine",
            )

            diff_completion_bits = generated_bits_diff.squeeze().tolist()[len(prompt_bits):]
            try:
                diff_completion = bits_to_text(diff_completion_bits)
                diff_success = True
            except Exception:
                diff_completion = "DECODE_FAILED"
                diff_success = False

            print(f"  Result: '{prompt}' -> '{prompt + diff_completion}' (Success: {diff_success})")

        except Exception as e:
            diff_completion = f"ERROR: {e}"
            diff_success = False
            print(f"  Result: ERROR - {e}")

        comparison_results.append({
            "prompt": prompt,
            "autoregressive": {"completion": ar_completion, "success": ar_success},
            "diffusion": {"completion": diff_completion, "success": diff_success},
        })

        if diff_success and ar_success:
            print("  Both methods succeeded!")
        elif diff_success and not ar_success:
            print("  Diffusion wins - only it succeeded!")
        elif ar_success and not diff_success:
            print("  Autoregressive wins - only it succeeded!")
        else:
            print("  Both methods failed")

    return comparison_results


def main():
    """Run all diffusion inference tests."""
    print("BITTRANSFORMERLM DENOISING DIFFUSION INFERENCE TESTS")
    print("=" * 70)
    print("Testing hypothesis: denoising diffusion should reduce parity errors")
    print("by treating parity bits as noise and filtering them out.")
    print("=" * 70)

    model = load_breakthrough_model()

    test_results = {
        "basic_diffusion": test_basic_diffusion_generation(model),
        "conditioned_diffusion": test_conditioned_diffusion_generation(model),
        "code_diffusion": test_code_diffusion_completion(model),
        "comparison": compare_diffusion_vs_autoregressive(model),
    }

    print("\n=== FINAL SUMMARY ===")

    basic_successes = sum(1 for r in test_results["basic_diffusion"] if r.get("success", False))
    print(f"Basic diffusion success rate: {basic_successes}/{len(test_results['basic_diffusion'])}")

    cond_successes = sum(1 for r in test_results["conditioned_diffusion"] if r.get("success", False))
    print(f"Conditioned diffusion success rate: {cond_successes}/{len(test_results['conditioned_diffusion'])}")

    code_successes = sum(1 for r in test_results["code_diffusion"] if r.get("success", False))
    print(f"Code diffusion success rate: {code_successes}/{len(test_results['code_diffusion'])}")

    diff_wins = sum(1 for r in test_results["comparison"]
                    if r["diffusion"]["success"] and not r["autoregressive"]["success"])
    ar_wins = sum(1 for r in test_results["comparison"]
                  if r["autoregressive"]["success"] and not r["diffusion"]["success"])
    both_win = sum(1 for r in test_results["comparison"]
                   if r["diffusion"]["success"] and r["autoregressive"]["success"])

    print(f"Method comparison - Diffusion only: {diff_wins}, Autoregressive only: {ar_wins}, Both: {both_win}")

    return test_results


if __name__ == "__main__":
    main()