|
|
|
|
|
""" |
|
|
Full Attention BitTransformerLM Diffusion Inference Test |
|
|
======================================================== |
|
|
|
|
|
Test the newly trained full bi-directional attention BitTransformerLM model |
|
|
using denoising diffusion generation to evaluate improvements from full attention training. |
|
|
|
|
|
Model Configuration: |
|
|
- Same full bi-directional unchunked attention as training (chunk_size=None) |
|
|
- Proper eval() mode with dropout management |
|
|
- Use latest checkpoint_best.pt from full attention training |
|
|
- Test with same diffusion inference that worked before |
|
|
""" |
|
|
|
|
|
import sys |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from datetime import datetime |
|
|
|
|
|
sys.path.append('/data') |
|
|
sys.path.append('/data/BitTransformerLM') |
|
|
|
|
|
from bit_transformer import ( |
|
|
BitTransformerLM, |
|
|
text_to_bits, |
|
|
bits_to_text, |
|
|
diffusion_inference, |
|
|
set_dropout |
|
|
) |
|
|
|
|
|
def load_full_attention_model():
    """Load the newly trained full attention BitTransformerLM model.

    Rebuilds the model with the exact full bi-directional, unchunked
    attention configuration used during training (chunk_size=None), restores
    weights from the best checkpoint, and switches to deterministic
    inference mode (eval + dropout zeroed).

    Returns:
        BitTransformerLM: restored model, ready for diffusion inference.
    """
    print("π Loading Full Attention BitTransformerLM for diffusion inference...")

    # Configuration must mirror training exactly so the checkpoint's
    # state_dict keys and tensor shapes line up.
    model = BitTransformerLM(
        d_model=512,
        nhead=16,
        num_layers=8,
        dim_feedforward=1024,
        max_seq_len=512,
        reversible=True,
        use_checkpoint=False,
        use_autocast=False,
        use_act=True,
        act_threshold=0.9,
        lambda_K=0.05,
        lambda_C=0.05,
        lambda_S=0.05,
        chunk_size=None,  # full (unchunked) bi-directional attention, as trained
        overlap=0,
        full_attn_logging=True,
    )

    checkpoint_path = '/data/BitTransformerLM/checkpoints/checkpoint_best.pt'
    # NOTE(review): torch.load unpickles arbitrary objects. Acceptable for
    # our own trusted checkpoint, but never point this at untrusted files.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])

    # Deterministic inference: eval mode plus explicitly zeroed dropout.
    model.eval()
    set_dropout(model, 0.0)

    # Training metadata is optional in older checkpoints, hence .get().
    epoch = checkpoint.get('epoch', 'unknown')
    loss = checkpoint.get('loss', 'unknown')
    print(f"✅ Full Attention Model loaded! Epoch: {epoch}, Loss: {loss}")

    total_params = sum(p.numel() for p in model.parameters())
    print(f"π Parameters: {total_params:,}")

    return model
|
|
|
|
|
def test_basic_diffusion_generation(model):
    """Test basic unconditional diffusion generation.

    Runs three generation configs (one per noise schedule) and records
    success/failure plus decoded output for each.

    Args:
        model: BitTransformerLM in eval mode.

    Returns:
        list[dict]: one result record per config, with a "success" flag.
    """
    print("\nπ§ͺ === BASIC FULL ATTENTION DIFFUSION GENERATION ===")

    results = []

    # Lengths are multiples of 9: the char-count print below divides by 9,
    # so text_to_bits presumably encodes 9 bits per character -- TODO confirm.
    test_configs = [
        {"length": 36, "steps": 8, "schedule": "linear"},
        {"length": 45, "steps": 12, "schedule": "cosine"},
        {"length": 54, "steps": 16, "schedule": "exp"},
    ]

    for i, config in enumerate(test_configs, 1):
        print(f"\n--- Test {i}: {config['length']//9} chars, {config['schedule']} ---")

        try:
            generated_bits = diffusion_inference(
                model,
                length=config['length'],
                steps=config['steps'],
                batch_size=1,
                schedule=config['schedule'],
            )

            # Flatten (1, length) tensor to a plain bit list and decode.
            bit_list = generated_bits.squeeze().tolist()
            decoded_text = bits_to_text(bit_list)

            print(f"✅ SUCCESS: '{decoded_text}'")
            results.append({
                "test": f"basic_{i}",
                "config": config,
                "success": True,
                "output": decoded_text,
                "bits": len(bit_list),
            })

        except Exception as e:
            # Record failures instead of aborting so remaining configs still run.
            print(f"β FAILED: {e}")
            results.append({
                "test": f"basic_{i}",
                "config": config,
                "success": False,
                "error": str(e),
            })

    return results
|
|
|
|
|
def test_conditioned_diffusion_generation(model):
    """Test prompt-conditioned diffusion generation.

    For each prompt, generates a fixed-length continuation and records the
    combined prompt + continuation text.

    Args:
        model: BitTransformerLM in eval mode.

    Returns:
        list[dict]: one result record per prompt, with a "success" flag.
    """
    print("\nπ― === CONDITIONED FULL ATTENTION DIFFUSION GENERATION ===")

    results = []

    test_prompts = [
        "Hello",
        "Hi there",
        "What is your name?",
        "The weather is",
        "I am",
        "Yes",
        "No",
    ]

    for prompt in test_prompts:
        print(f"\n--- Prompt: '{prompt}' ---")

        try:
            # Currently unused by the sampler call below -- see NOTE.
            prompt_bits = text_to_bits(prompt)

            continuation_length = 45
            # NOTE(review): init_bits=None means the sampler is NOT actually
            # conditioned on the prompt -- generation is unconditional and the
            # prompt is only prepended to the decoded text. To truly condition,
            # prompt_bits would need to be passed via init_bits; confirm
            # against diffusion_inference's API before changing.
            generated_bits = diffusion_inference(
                model,
                length=continuation_length,
                steps=12,
                batch_size=1,
                init_bits=None,
                schedule="cosine",
            )

            continuation_bits = generated_bits.squeeze().tolist()
            continuation_text = bits_to_text(continuation_bits)

            combined_text = prompt + continuation_text
            print(f"✅ SUCCESS: '{prompt}' β '{combined_text}'")
            results.append({
                "test": "conditioned",
                "prompt": prompt,
                "success": True,
                "full_output": combined_text,
                "continuation": continuation_text,
                "bits": len(continuation_bits),
            })

        except Exception as e:
            # Record failures instead of aborting so remaining prompts still run.
            print(f"β FAILED: {e}")
            results.append({
                "test": "conditioned",
                "prompt": prompt,
                "success": False,
                "error": str(e),
            })

    return results
|
|
|
|
|
def test_code_diffusion_completion(model):
    """Test code/math completion with diffusion.

    Generates a fixed-length completion for each code/math snippet and runs a
    light heuristic analysis (alphanumerics, digits, code symbols) on the
    generated text.

    Args:
        model: BitTransformerLM in eval mode.

    Returns:
        list[dict]: one result record per snippet, with a "success" flag.
    """
    print("\nπ» === CODE COMPLETION FULL ATTENTION DIFFUSION ===")

    results = []

    test_cases = [
        # arithmetic
        "2 + 2 =",
        "1 + 1 =",
        "5 * 3 =",
        "10 / 2 =",
        # code fragments
        "def hello():",
        "if x ==",
        "for i in",
        "print(",
        "return",
        # sequences / misc syntax
        "a, b, c,",
        "1, 2, 3,",
        "function(",
        "var x =",
    ]

    # (predicate, label, print icon) triples for the completion analysis.
    analysis_checks = (
        (lambda s: any(c.isalnum() for c in s), "Contains alphanumeric", "π"),
        (lambda s: any(c in "0123456789" for c in s), "Contains numbers", "π’"),
        (lambda s: any(c in "=(){}[];," for c in s), "Contains code symbols", "π»"),
    )

    for code in test_cases:
        print(f"\n--- Code: '{code}' ---")

        try:
            # NOTE(review): the prompt bits are not fed to the sampler
            # (init_bits=None), so the completion is unconditional and the
            # snippet is only prepended to the decoded text -- confirm against
            # diffusion_inference's API if true conditioning is intended.
            completion_length = 45
            generated_bits = diffusion_inference(
                model,
                length=completion_length,
                steps=10,
                batch_size=1,
                init_bits=None,
                schedule="linear",
            )

            completion_bits = generated_bits.squeeze().tolist()
            completion = bits_to_text(completion_bits)

            combined_text = code + completion
            print(f"✅ SUCCESS: '{code}' β '{combined_text}'")

            # Heuristic content analysis of the generated completion.
            analysis = []
            for predicate, label, icon in analysis_checks:
                if predicate(completion):
                    analysis.append(label)
                    print(f" {icon} Analysis: {label}")

            results.append({
                "test": "code_completion",
                "prompt": code,
                "success": True,
                "full_output": combined_text,
                "completion": completion,
                "analysis": analysis,
                "bits": len(completion_bits),
            })

        except Exception as e:
            # Record failures instead of aborting so remaining cases still run.
            print(f"β FAILED: {e}")
            results.append({
                "test": "code_completion",
                "prompt": code,
                "success": False,
                "error": str(e),
            })

    return results
|
|
|
|
|
def compare_with_previous_results():
    """Note about comparison with previous results."""
    print("\nβοΈ === COMPARISON WITH PREVIOUS RESULTS ===")
    # Baseline numbers from the earlier chunked-attention run, for context.
    baseline_lines = (
        "Previous chunked attention model achieved:",
        "- Basic generation: 3/3 success (100%)",
        "- Conditioned generation: 7/7 success (100%)",
        "- Code completion: 13/13 success (100%)",
        "- All diffusion inference succeeded vs 0% autoregressive",
    )
    for line in baseline_lines:
        print(line)
    print("\nTesting if full attention training improved quality...")
|
|
|
|
|
def main():
    """Run every diffusion test suite, print a summary, and return all results.

    Returns:
        dict: per-suite result lists plus a 'summary' block with totals,
        success rate, and an ISO timestamp.
    """
    banner = "=" * 70
    print("π FULL ATTENTION BITRANSFORMERLM DIFFUSION INFERENCE TEST")
    print(banner)
    print("Testing newly trained full bi-directional attention model")
    print("with denoising diffusion generation")
    print(banner)

    model = load_full_attention_model()

    # Run the three suites in order; each returns a list of result records.
    basic_results = test_basic_diffusion_generation(model)
    conditioned_results = test_conditioned_diffusion_generation(model)
    code_results = test_code_diffusion_completion(model)

    compare_with_previous_results()

    # Aggregate across all suites.
    all_results = basic_results + conditioned_results + code_results
    total_tests = len(all_results)
    successful_tests = sum(1 for r in all_results if r.get('success', False))
    success_rate = (successful_tests / total_tests) * 100 if total_tests > 0 else 0

    print(f"\nπ― === FINAL SUMMARY ===")
    print(f"Total tests: {total_tests}")
    print(f"Successful: {successful_tests}")
    print(f"Success rate: {success_rate:.1f}%")

    print(f"\nBreakdown:")
    for label, group in (
        ("Basic generation", basic_results),
        ("Conditioned generation", conditioned_results),
        ("Code completion", code_results),
    ):
        passed = sum(1 for r in group if r.get('success', False))
        print(f"- {label}: {passed}/{len(group)}")

    return {
        'basic_results': basic_results,
        'conditioned_results': conditioned_results,
        'code_results': code_results,
        'summary': {
            'total_tests': total_tests,
            'successful_tests': successful_tests,
            'success_rate': success_rate,
            'timestamp': datetime.now().isoformat(),
        },
    }
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run all suites; keep the results dict for inspection.
    results = main()