"""
Debug BitTransformerLM Generation
"""
|
| | import sys |
| | import torch |
| | import torch.nn.functional as F |
| |
|
# Make the BitTransformerLM checkout importable; order matches the original
# two appends ('/data' first, then the package directory).
sys.path.extend(['/data', '/data/BitTransformerLM'])
| |
|
| | from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text |
| |
|
def load_model(checkpoint_path='/data/BitTransformerLM/checkpoints/checkpoint_best.pt'):
    """Build the BitTransformerLM used for debugging and restore its weights.

    Args:
        checkpoint_path: Path of the checkpoint file to load. Defaults to the
            best checkpoint under ``/data/BitTransformerLM/checkpoints``.

    Returns:
        A ``(model, loss)`` tuple: the model in eval mode and the value stored
        under the checkpoint's ``'loss'`` key.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` or ``'loss'``
            entry.
    """
    # NOTE(review): these hyperparameters presumably mirror the training run
    # that produced the checkpoint -- confirm before changing any of them.
    model = BitTransformerLM(
        d_model=512,
        nhead=16,
        num_layers=8,
        dim_feedforward=1024,
        max_seq_len=512,
        reversible=True,
        use_checkpoint=False,
        use_autocast=False,
        use_act=True,
        act_threshold=0.9,
        lambda_K=0.05,
        lambda_C=0.05,
        lambda_S=0.05,
    )

    # NOTE(review): torch.load unpickles the file, so only point this at
    # trusted checkpoints (weights_only=True may be an option if the file
    # holds nothing but tensors and primitives -- verify first).
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    return model, checkpoint['loss']
| |
|
def generate_longer(model, prompt, num_chars=10, *, temperature=0.7,
                    max_context_bits=400):
    """Sample bits from ``model`` after ``prompt`` and decode them as text.

    Progress is printed after every full character's worth of bits; the final
    decode result, or a diagnostic dump when decoding fails, is printed too.

    Args:
        model: Callable that takes a ``(1, T)`` long tensor of bits and
            returns a ``(logits, telemetry)`` pair.
        prompt: Text that seeds the generation.
        num_chars: Number of characters to generate (9 bits are sampled per
            character).
        temperature: Divisor applied to the last-position logits before the
            softmax; must be positive.
        max_context_bits: Only this many of the most recent bits are fed to
            the model on each step; must be positive.

    Returns:
        The decoded generated text (prompt excluded), or ``None`` when the
        generated bits cannot be decoded.

    Raises:
        ValueError: If ``temperature`` or ``max_context_bits`` is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if max_context_bits <= 0:
        # A zero window would make the [-0:] slice below return everything.
        raise ValueError("max_context_bits must be positive")

    # NOTE(review): 9 bits per character presumably means 8 data bits plus a
    # parity bit -- confirm against text_to_bits.
    bits_per_char = 9

    print(f"\n🎯 Generating {num_chars} characters from: '{prompt}'")

    input_bits = text_to_bits(prompt)
    prompt_len = len(input_bits)
    print(f"Input: {prompt_len} bits")

    generated_bits = list(input_bits)

    with torch.no_grad():
        for step in range(1, num_chars * bits_per_char + 1):
            # Negative slicing already copes with sequences shorter than the
            # window, so no explicit length check is needed.
            context_tensor = torch.tensor(
                generated_bits[-max_context_bits:], dtype=torch.long
            ).unsqueeze(0)

            logits, _telemetry = model(context_tensor)
            # assumes logits is (batch, seq, vocab) -- TODO confirm
            next_bit_logits = logits[0, -1, :] / temperature
            probs = F.softmax(next_bit_logits, dim=-1)
            next_bit = torch.multinomial(probs, 1).item()

            generated_bits.append(next_bit)

            # Report progress each time a whole character has been sampled.
            if step % bits_per_char == 0:
                chars_done = step // bits_per_char
                try:
                    partial_text = bits_to_text(generated_bits[prompt_len:])
                except Exception as exc:
                    # Best-effort progress output: report and keep sampling.
                    print(f" After {chars_done} chars: <decode failed: {exc}>")
                else:
                    print(f" After {chars_done} chars: '{partial_text}'")

    generated_only = generated_bits[prompt_len:]
    try:
        final_text = bits_to_text(generated_only)
    except Exception as e:
        print(f"❌ Final decode failed: {e}")
        print(f"Generated {len(generated_only)} bits: {generated_only[:50]}...")

        # Fall back to decoding the first one, two and three characters so
        # the point where decoding breaks becomes visible.
        print("🔧 Trying chunk decoding...")
        for chunk_size in (bits_per_char, 2 * bits_per_char, 3 * bits_per_char):
            if len(generated_only) < chunk_size:
                continue
            try:
                chunk_text = bits_to_text(generated_only[:chunk_size])
                print(f" First {chunk_size // bits_per_char} chars: '{chunk_text}'")
            except Exception as ce:
                print(f" {chunk_size // bits_per_char} chars failed: {ce}")

        return None

    print(f"✨ Final result: '{prompt}' + '{final_text}'")
    return final_text
| |
|
def test_bit_encoding():
    """Round-trip sample strings through text_to_bits / bits_to_text.

    Prints one line per sample: a check mark when the decoded text equals the
    input, a cross when it differs, or the error raised while decoding.
    """
    print("\n🔧 Testing bit encoding/decoding...")

    test_strings = ["A", "AB", "Hello", "Hi there!"]

    for s in test_strings:
        bits = text_to_bits(s)
        try:
            decoded = bits_to_text(bits)
        except Exception as e:
            print(f"❌ '{s}' -> {len(bits)} bits -> ERROR: {e}")
            continue

        status = "✅" if decoded == s else "❌"
        print(f"{status} '{s}' -> {len(bits)} bits -> '{decoded}'")
| |
|
def main():
    """Run the encoding self-test, load the checkpoint and sample a few prompts."""
    # NOTE(review): the banner glyph was unreadable in the source text; the
    # magnifying glass is a stand-in -- confirm the intended emoji.
    print("🔍 BITRANSFORMERLM GENERATION DEBUG")
    print("=" * 50)

    # Check the bit codec first so decode failures during generation can be
    # told apart from a broken encoder/decoder.
    test_bit_encoding()

    model, loss = load_model()
    print(f"\n✅ Model loaded! Loss: {loss:.6f}")

    prompts = ["Hello", "Hi", "A", "The"]

    for prompt in prompts:
        generate_longer(model, prompt, num_chars=3)
| |
|
# Script entry point: run the debug session only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()