| | |
| | """ |
| | Enhanced BitTransformerLM Generation Testing |
| | ============================================= |
| | |
| | Test the promising generation improvements: |
| | 1. Autoregressive generation with automatic parity correction |
| | 2. Longer sequence generation (50, 100, 200+ characters) |
| | 3. Optimized diffusion parameters (50+ steps) |
| | 4. Direct comparison between generation methods |
| | |
| | Goal: See if we can get from "barely-contextual gibberish" to actual language! |
| | """ |
| |
|
| | import sys |
| | import torch |
| | import torch.nn.functional as F |
| | from datetime import datetime |
| |
|
| | sys.path.append('/data') |
| | sys.path.append('/data/BitTransformerLM') |
| |
|
| | from bit_transformer import ( |
| | BitTransformerLM, |
| | text_to_bits, |
| | bits_to_text, |
| | diffusion_inference, |
| | set_dropout, |
| | enforce_parity |
| | ) |
| |
|
def load_full_attention_model(checkpoint_path='/data/BitTransformerLM/checkpoints/checkpoint_best.pt'):
    """Build the full-attention BitTransformerLM and load its trained weights.

    Args:
        checkpoint_path: Checkpoint file to read. It must hold a
            ``'model_state_dict'`` entry; ``'epoch'`` and ``'loss'`` are
            reported when present and shown as ``'unknown'`` otherwise.

    Returns:
        The model, switched to eval mode and with dropout set to 0.0
        through ``set_dropout``.
    """
    print("π Loading Full Attention BitTransformerLM for enhanced generation testing...")

    model = BitTransformerLM(
        d_model=512, nhead=16, num_layers=8, dim_feedforward=1024,
        max_seq_len=512, reversible=True, use_checkpoint=False,
        use_autocast=False, use_act=True, act_threshold=0.9,
        lambda_K=0.05, lambda_C=0.05, lambda_S=0.05,
        chunk_size=None, overlap=0, full_attn_logging=True
    )

    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should be passed here. Consider weights_only=True once it is confirmed
    # that the checkpoint contains nothing but tensors and plain containers.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])

    # Inference-only configuration: eval mode plus dropout forced to zero.
    model.eval()
    set_dropout(model, 0.0)

    epoch = checkpoint.get('epoch', 'unknown')
    loss = checkpoint.get('loss', 'unknown')
    print(f"✅ Model loaded! Epoch: {epoch}, Loss: {loss}")

    return model
| |
|
def autoregressive_generate_with_parity_correction(model, prompt, max_new_chars=20, temperature=0.7):
    """Generate text bit by bit, forcing a valid parity bit on every character.

    Each new character is nine bits: eight data bits taken from the model's
    prediction for the next position, followed by a parity bit computed as
    ``sum(data_bits) % 2``. The generated bits are then passed through
    ``enforce_parity`` and decoded with ``bits_to_text``.

    Args:
        model: Called as ``model(tensor, causal=True)``; must return a
            ``(logits, telemetry)`` pair.
        prompt: Text that seeds the context (encoded with ``text_to_bits``).
        max_new_chars: Number of 9-bit characters to generate.
        temperature: Sampling temperature. Values above 0 sample from the
            temperature-scaled softmax; otherwise the argmax bit is taken.

    Returns:
        On success, a dict with ``'success'`` (True), ``'full_text'``,
        ``'new_text'``, ``'bits_generated'`` and ``'parity_corrections'``
        (the second value returned by ``enforce_parity``). If decoding
        raises, a dict with ``'success'`` (False), ``'error'`` and
        ``'bits_generated'``.
    """
    data_bits_per_char = 8
    context_window = 400  # only the most recent 400 bits are fed to the model

    print("\nπ Autoregressive generation with parity correction:")
    print(f" Prompt: '{prompt}' β generating {max_new_chars} characters...")

    input_bits = text_to_bits(prompt)
    generated_bits = list(input_bits)

    with torch.no_grad():
        for _ in range(max_new_chars):
            char_bits = []

            for _ in range(data_bits_per_char):
                # Slicing already returns the whole list when it is shorter
                # than the window, so no explicit length check is needed.
                context = (generated_bits + char_bits)[-context_window:]
                context_tensor = torch.tensor(context, dtype=torch.long).unsqueeze(0)

                logits, _ = model(context_tensor, causal=True)
                next_bit_logits = logits[0, -1, :]

                if temperature > 0:
                    probs = F.softmax(next_bit_logits / temperature, dim=-1)
                    next_bit = torch.multinomial(probs, 1).item()
                else:
                    next_bit = torch.argmax(next_bit_logits).item()

                char_bits.append(next_bit)

            # The parity bit is fully determined by the eight data bits, so
            # no model call is spent on it.
            char_bits.append(sum(char_bits) % 2)
            generated_bits.extend(char_bits)

    new_bits = generated_bits[len(input_bits):]

    new_bits_tensor = torch.tensor(new_bits, dtype=torch.long)
    corrected_bits_tensor, parity_corrections = enforce_parity(new_bits_tensor)
    corrected_bits = corrected_bits_tensor.tolist()

    try:
        decoded_text = bits_to_text(corrected_bits)
        full_result = prompt + decoded_text
    except Exception as e:
        # Decoding failures are reported in the result, not raised, so a
        # batch of test runs keeps going.
        print(f" β DECODE FAILED: {e}")
        return {
            'success': False,
            'error': str(e),
            'bits_generated': len(new_bits)
        }

    print(f" ✅ SUCCESS: '{full_result}'")
    return {
        'success': True,
        'full_text': full_result,
        'new_text': decoded_text,
        'bits_generated': len(new_bits),
        'parity_corrections': parity_corrections
    }
| |
|
def long_diffusion_generation(model, prompt, target_chars, steps=50):
    """Sample ``target_chars`` characters with ``diffusion_inference`` and decode them.

    Nine bits are requested per character, using a single-sample batch and
    the ``"cosine"`` schedule.

    NOTE(review): ``init_bits=None`` is passed and ``prompt`` is only
    prepended to the decoded text, so the sampled bits do not appear to be
    conditioned on the prompt. Confirm against ``diffusion_inference``.

    Args:
        model: Model handed to ``diffusion_inference``.
        prompt: Text prepended to the decoded sample in ``'full_text'``.
        target_chars: Number of characters to generate.
        steps: Number of diffusion steps.

    Returns:
        On success, a dict with ``'success'`` (True), ``'full_text'``,
        ``'new_text'``, ``'bits_generated'`` and ``'diffusion_steps'``. On
        any exception, a dict with ``'success'`` (False), ``'error'`` and
        ``'diffusion_steps'``.
    """
    print("\nπ Long diffusion generation:")
    print(f" Prompt: '{prompt}' β generating {target_chars} characters with {steps} steps...")

    try:
        continuation_bits = target_chars * 9
        generated_bits = diffusion_inference(
            model,
            length=continuation_bits,
            steps=steps,
            batch_size=1,
            init_bits=None,
            schedule="cosine"
        )

        # Flatten the single-sample batch into a plain 1-D list of bits.
        continuation_bits_list = generated_bits.reshape(-1).tolist()
        continuation_text = bits_to_text(continuation_bits_list)

        full_result = prompt + continuation_text
        print(f" ✅ SUCCESS: '{full_result}'")

        return {
            'success': True,
            'full_text': full_result,
            'new_text': continuation_text,
            'bits_generated': len(continuation_bits_list),
            'diffusion_steps': steps
        }

    except Exception as e:
        # Failures are reported in the result, not raised, so the
        # surrounding test sweep can continue.
        print(f" β FAILED: {e}")
        return {
            'success': False,
            'error': str(e),
            'diffusion_steps': steps
        }
| |
|
def test_length_scaling():
    """Run both generators over every prompt/length pair.

    Returns:
        A list with one dict per pair, holding ``'prompt'``,
        ``'target_length'`` and the result dicts of the autoregressive and
        diffusion generators under ``'autoregressive'`` and ``'diffusion'``.
    """
    print("\nπ === LENGTH SCALING TESTS ===")
    print("Testing if longer generations show improved coherence...")

    model = load_full_attention_model()

    # Prompt-major order: every length is tried for a prompt before moving on.
    cases = [
        (prompt, length)
        for prompt in ("Hello", "The weather today", "I think that")
        for length in (10, 25, 50)
    ]

    collected = []
    for prompt, length in cases:
        print(f"\n--- Testing '{prompt}' β {length} chars ---")

        record = {'prompt': prompt, 'target_length': length}
        record['autoregressive'] = autoregressive_generate_with_parity_correction(
            model, prompt, max_new_chars=length, temperature=0.6
        )
        record['diffusion'] = long_diffusion_generation(
            model, prompt, target_chars=length, steps=50
        )
        collected.append(record)

    return collected
| |
|
def test_parameter_optimization():
    """Sweep sampling temperature and diffusion step count on one prompt.

    Returns:
        A list of dicts. Autoregressive entries carry ``'method'``,
        ``'temperature'`` and ``'result'``; diffusion entries carry
        ``'method'``, ``'steps'`` and ``'result'``.
    """
    print("\nβοΈ === PARAMETER OPTIMIZATION TESTS ===")
    print("Testing different temperatures and diffusion steps...")

    model = load_full_attention_model()
    prompt = "Hello world"
    temperature_sweep = (0.1, 0.5, 0.8, 1.0, 1.2)
    step_sweep = (10, 25, 50, 100)

    outcomes = []

    print("\nπ‘οΈ Testing autoregressive temperatures:")
    for temperature in temperature_sweep:
        print(f"\n--- Temperature {temperature} ---")
        outcomes.append({
            'method': 'autoregressive',
            'temperature': temperature,
            'result': autoregressive_generate_with_parity_correction(
                model, prompt, max_new_chars=20, temperature=temperature
            ),
        })

    print("\nπ Testing diffusion steps:")
    for step_count in step_sweep:
        print(f"\n--- {step_count} steps ---")
        outcomes.append({
            'method': 'diffusion',
            'steps': step_count,
            'result': long_diffusion_generation(
                model, prompt, target_chars=20, steps=step_count
            ),
        })

    return outcomes
| |
|
# Short, frequent English words used as a crude signal of language-like output.
_COMMON_WORDS = frozenset({'the', 'and', 'is', 'in', 'to', 'a'})


def _contains_common_word(text):
    """Return True when *text* contains one of ``_COMMON_WORDS`` as a whole word."""
    # Every non-letter is treated as a separator, so symbols in noisy output
    # cannot glue onto a word and a lone letter inside gibberish does not count.
    letters_only = ''.join(ch if ch.isalpha() else ' ' for ch in text.lower())
    return not _COMMON_WORDS.isdisjoint(letters_only.split())


def test_coherence_prompts():
    """Run both generators on prompts chosen to elicit familiar continuations.

    For each prompt, 30 characters are generated autoregressively
    (temperature 0.7) and by diffusion (75 steps). A note is printed when a
    successful result contains a common English word as a whole word.

    Returns:
        A list with one dict per prompt, holding ``'prompt'``,
        ``'autoregressive'`` and ``'diffusion'``.
    """
    print("\nπ― === COHERENCE PROMPTS TESTS ===")
    print("Testing prompts designed to elicit coherent language patterns...")

    model = load_full_attention_model()

    coherence_prompts = [
        "Once upon a time",
        "The quick brown fox",
        "In the beginning",
        "Python code to print hello:",
        "def main():",
        "SELECT * FROM",
        "Today is a beautiful",
        "My name is",
        "The answer is",
        "import torch"
    ]

    results = []

    for prompt in coherence_prompts:
        print(f"\n--- Testing coherence with: '{prompt}' ---")

        auto_result = autoregressive_generate_with_parity_correction(
            model, prompt, max_new_chars=30, temperature=0.7
        )

        diff_result = long_diffusion_generation(
            model, prompt, target_chars=30, steps=75
        )

        results.append({
            'prompt': prompt,
            'autoregressive': auto_result,
            'diffusion': diff_result
        })

        if auto_result.get('success') and _contains_common_word(auto_result.get('new_text', '')):
            print(" π Autoregressive contains common words!")

        if diff_result.get('success') and _contains_common_word(diff_result.get('new_text', '')):
            print(" π Diffusion contains common words!")

    return results
| |
|
def main():
    """Run all enhanced generation tests and print an overall summary.

    Success counts and the search for word-like output cover the
    length-scaling and coherence runs; the parameter sweep is returned but
    not included in those counts.

    Returns:
        A dict with ``'length_results'``, ``'param_results'``,
        ``'coherence_results'`` and a ``'summary'`` dict holding
        ``'autoregressive_successes'``, ``'diffusion_successes'`` and
        ``'promising_outputs'`` (a count).
    """
    print("π ENHANCED BITRANSFORMERLM GENERATION TESTING")
    print("=" * 60)
    print("Testing potential fixes:")
    print("1. Autoregressive with parity correction")
    print("2. Longer sequence generation")
    print("3. Optimized generation parameters")
    print("4. Coherence-focused prompts")
    print("=" * 60)

    length_results = test_length_scaling()
    param_results = test_parameter_optimization()
    coherence_results = test_coherence_prompts()

    print("\nπ― === OVERALL ANALYSIS ===")

    all_results = length_results + coherence_results

    # The denominator is taken from the runs actually aggregated, so it stays
    # correct when the prompt or length lists change.
    total_runs = len(all_results)
    total_auto = sum(
        1 for r in all_results if r.get('autoregressive', {}).get('success')
    )
    total_diff = sum(
        1 for r in all_results if r.get('diffusion', {}).get('success')
    )

    print(f"Autoregressive success rate: {total_auto}/{total_runs}")
    print(f"Diffusion success rate: {total_diff}/{total_runs}")

    print("\nπ Looking for signs of linguistic improvement...")

    promising_outputs = []

    for result in all_results:
        for method in ('autoregressive', 'diffusion'):
            if not result.get(method, {}).get('success'):
                continue

            text = result[method].get('new_text', '')

            # Word-like heuristic: more than 10 characters, at least one
            # letter, and a purely alphabetic token longer than two characters.
            if len(text) <= 10 or not any(c.isalpha() for c in text):
                continue
            if any(len(word) > 2 and word.isalpha() for word in text.split()):
                promising_outputs.append({
                    'prompt': result['prompt'],
                    'method': method,
                    'text': text
                })

    if promising_outputs:
        print(f"\nπ Found {len(promising_outputs)} promising outputs with word-like patterns!")
        for output in promising_outputs[:5]:
            print(f" {output['method']}: '{output['prompt']}' β '{output['text']}'")
    else:
        print("\nπ No clear word patterns found yet - model may need more training or different approach")

    return {
        'length_results': length_results,
        'param_results': param_results,
        'coherence_results': coherence_results,
        'summary': {
            'autoregressive_successes': total_auto,
            'diffusion_successes': total_diff,
            'promising_outputs': len(promising_outputs)
        }
    }
| |
|
# Script entry point: run the full test suite only when executed directly.
if __name__ == "__main__":
    results = main()