#!/usr/bin/env python3
"""
Debug BitTransformerLM Generation
"""

import sys

import torch
import torch.nn.functional as F

sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text

# Checkpoint loaded by load_model() when no path is given.
CHECKPOINT_PATH = '/data/BitTransformerLM/checkpoints/checkpoint_best.pt'

# Per the original generation comment, each character is 9 bits
# (8 data bits + 1 parity bit). Assumed to match text_to_bits/bits_to_text —
# TODO confirm against bit_transformer.
BITS_PER_CHAR = 9

# Number of most-recent bits fed to the model at each sampling step; kept
# below the max_seq_len=512 configured in load_model().
DEFAULT_CONTEXT_BITS = 400

# Softmax temperature used for sampling unless the caller overrides it.
DEFAULT_TEMPERATURE = 0.7


def load_model(checkpoint_path=CHECKPOINT_PATH):
    """Build a BitTransformerLM and load its weights from a checkpoint.

    The hyperparameters below presumably must match the ones the checkpoint
    was trained with; otherwise load_state_dict() (strict by default) raises
    on mismatched keys or shapes.

    Args:
        checkpoint_path: Path to a checkpoint dict that contains at least the
            keys 'model_state_dict' and 'loss'.

    Returns:
        A ``(model, loss)`` tuple: the model switched to eval mode and the
        value stored under the checkpoint's 'loss' key.
    """
    model = BitTransformerLM(
        d_model=512,
        nhead=16,
        num_layers=8,
        dim_feedforward=1024,
        max_seq_len=512,
        reversible=True,
        use_checkpoint=False,
        use_autocast=False,
        use_act=True,
        act_threshold=0.9,
        lambda_K=0.05,
        lambda_C=0.05,
        lambda_S=0.05,
    )
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model, checkpoint['loss']


def _sample_next_bit(model, context_bits, temperature):
    """Run *model* on *context_bits* and pick the next bit from the last position.

    A positive *temperature* samples from the tempered softmax; a
    non-positive one picks the highest-scoring bit (greedy decoding).
    """
    context_tensor = torch.tensor(context_bits, dtype=torch.long).unsqueeze(0)
    logits, _telemetry = model(context_tensor)
    # Assumes logits is (batch, seq, vocab) — TODO confirm; take the scores
    # for the last position of the single batch item.
    next_bit_logits = logits[0, -1, :]
    if temperature <= 0:
        return torch.argmax(next_bit_logits).item()
    probs = F.softmax(next_bit_logits / temperature, dim=-1)
    return torch.multinomial(probs, 1).item()


def _report_chunk_decodes(bits):
    """Try decoding the first 1, 2 and 3 characters of *bits*; print each outcome."""
    print("🔧 Trying chunk decoding...")
    for num_chars in (1, 2, 3):
        chunk_size = num_chars * BITS_PER_CHAR
        if len(bits) < chunk_size:
            # Chunks only grow from here, so none of the rest fit either.
            break
        try:
            chunk_text = bits_to_text(bits[:chunk_size])
        except Exception as ce:
            print(f"  {num_chars} chars failed: {ce}")
        else:
            print(f"  First {num_chars} chars: '{chunk_text}'")


def generate_longer(model, prompt, num_chars=10, temperature=DEFAULT_TEMPERATURE,
                    context_len=DEFAULT_CONTEXT_BITS):
    """Generate *num_chars* characters after *prompt*, one bit at a time.

    After every BITS_PER_CHAR sampled bits, the generated suffix is decoded
    and printed as progress. If the final decode fails, the first one to
    three characters are decoded separately to help locate the problem.

    Args:
        model: Callable as ``model(tensor)`` returning ``(logits, telemetry)``.
        prompt: Text to condition generation on.
        num_chars: Number of characters to generate.
        temperature: Softmax temperature; a value <= 0 means greedy decoding.
        context_len: Maximum number of trailing bits fed to the model per step.

    Returns:
        The decoded generated text, or None if it could not be decoded.

    Raises:
        ValueError: If *context_len* is less than 1.
    """
    if context_len < 1:
        raise ValueError(f"context_len must be >= 1, got {context_len}")

    print(f"\n🎯 Generating {num_chars} characters from: '{prompt}'")

    input_bits = text_to_bits(prompt)
    print(f"Input: {len(input_bits)} bits")
    prompt_len = len(input_bits)

    generated_bits = input_bits.copy()

    with torch.no_grad():
        for step in range(1, num_chars * BITS_PER_CHAR + 1):
            # Slicing also works when fewer than context_len bits exist.
            context = generated_bits[-context_len:]
            generated_bits.append(_sample_next_bit(model, context, temperature))

            if step % BITS_PER_CHAR != 0:
                continue
            try:
                partial_text = bits_to_text(generated_bits[prompt_len:])
            except Exception:
                # Intermediate output may not decode yet; this is progress
                # reporting only, so skip it and keep generating.
                continue
            print(f"  After {step // BITS_PER_CHAR} chars: '{partial_text}'")

    generated_only = generated_bits[prompt_len:]
    # Keep only the decode inside the try so output errors are not
    # misreported as decode failures.
    try:
        final_text = bits_to_text(generated_only)
    except Exception as e:
        print(f"❌ Final decode failed: {e}")
        print(f"Generated {len(generated_only)} bits: {generated_only[:50]}...")
        _report_chunk_decodes(generated_only)
        return None

    print(f"✨ Final result: '{prompt}' + '{final_text}'")
    return final_text


def test_bit_encoding():
    """Round-trip sample strings through text_to_bits/bits_to_text and print results."""
    print("\n🔧 Testing bit encoding/decoding...")

    test_strings = ["A", "AB", "Hello", "Hi there!"]
    for s in test_strings:
        bits = text_to_bits(s)
        try:
            decoded = bits_to_text(bits)
        except Exception as e:
            print(f"❌ '{s}' -> {len(bits)} bits -> ERROR: {e}")
            continue
        status = "✅" if decoded == s else "❌"
        print(f"{status} '{s}' -> {len(bits)} bits -> '{decoded}'")


def main():
    """Check the bit codec, load the checkpoint, and sample from a few prompts."""
    print("🚀 BITTRANSFORMERLM GENERATION DEBUG")
    print("=" * 50)

    # Test encoding first so codec problems are not blamed on the model.
    test_bit_encoding()

    model, loss = load_model()
    print(f"\n✅ Model loaded! Loss: {loss:.6f}")

    for prompt in ["Hello", "Hi", "A", "The"]:
        generate_longer(model, prompt, num_chars=3)


if __name__ == "__main__":
    main()