"""
Simple BitTransformerLM Test - No Interactive Input
"""

import sys

import torch
import torch.nn.functional as F

# Make the project package importable when run from the container root.
sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text
|
def test_breakthrough_model():
    """Load the best BitTransformerLM checkpoint and smoke-test generation.

    Non-interactive: runs a fixed set of text prompts through the model,
    prints the next-bit probability distribution and telemetry (K/C/S),
    then autoregressively samples 18 bits per prompt and attempts to
    decode them back to text.

    Side effects: reads the checkpoint file from disk and prints results
    to stdout. Returns None.
    """
    print("Loading breakthrough BitTransformerLM...")

    # Architecture must match the checkpoint exactly, otherwise
    # load_state_dict below will fail.
    model = BitTransformerLM(
        d_model=512,
        nhead=16,
        num_layers=8,
        dim_feedforward=1024,
        max_seq_len=512,
        reversible=True,
        use_checkpoint=False,
        use_autocast=False,
        use_act=True,
        act_threshold=0.9,
        lambda_K=0.05,
        lambda_C=0.05,
        lambda_S=0.05,
    )

    checkpoint = torch.load(
        '/data/BitTransformerLM/checkpoints/checkpoint_best.pt',
        map_location='cpu',
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # FIX: original f-string was broken across two lines (SyntaxError) and
    # contained mojibake; restored as a single valid line.
    print(f"Model loaded! Loss: {checkpoint['loss']:.6f}")

    prompts = [
        "Hello",
        "Hi there",
        "What is your name?",
        "The weather is",
    ]

    for prompt in prompts:
        print(f"\nTesting: '{prompt}'")

        input_bits = text_to_bits(prompt)
        input_tensor = torch.tensor(input_bits, dtype=torch.long).unsqueeze(0)

        print(f"Input: {len(input_bits)} bits")

        with torch.no_grad():
            try:
                logits, telemetry = model(input_tensor)

                # Probability distribution over the bit following the prompt.
                next_probs = F.softmax(logits[0, -1, :], dim=-1)
                print(f"Next bit probs: [0]={next_probs[0]:.3f}, [1]={next_probs[1]:.3f}")

                if telemetry:
                    # Telemetry entries may be tensors; reduce each to a
                    # scalar mean before formatting.
                    scores = {}
                    for label, key in (
                        ("K", "negentropy_logits"),
                        ("C", "lz_complexity_logits"),
                        ("S", "symbiosis_score"),
                    ):
                        value = telemetry.get(key, 0)
                        if torch.is_tensor(value):
                            value = value.mean().item()
                        scores[label] = value
                    print(
                        f"Telemetry: K={scores['K']:.3f}, "
                        f"C={scores['C']:.3f}, S={scores['S']:.3f}"
                    )

                # Autoregressive sampling: append one bit at a time,
                # re-running the model on the growing sequence.
                generated_bits = input_bits.copy()
                for _ in range(18):
                    current_tensor = torch.tensor(
                        generated_bits, dtype=torch.long
                    ).unsqueeze(0)
                    # Keep within max_seq_len=512 (500 leaves headroom).
                    if current_tensor.size(1) > 500:
                        current_tensor = current_tensor[:, -500:]

                    logits, _ = model(current_tensor)
                    next_bit_logits = logits[0, -1, :]

                    # Temperature 0.8 sharpens the distribution slightly.
                    next_bit_logits = next_bit_logits / 0.8
                    probs = F.softmax(next_bit_logits, dim=-1)
                    next_bit = torch.multinomial(probs, 1).item()

                    generated_bits.append(next_bit)

                # Decode only the newly sampled bits, not the prompt.
                generated_only = generated_bits[len(input_bits):]
                try:
                    generated_text = bits_to_text(generated_only)
                    print(f"Generated: '{generated_text}'")
                except Exception as e:
                    # Sampled bits may not align to valid characters.
                    print(f"Decode failed: {e}")
                    print(f"Raw bits: {generated_only}")

            except Exception as e:
                # Broad catch is deliberate: this is a best-effort smoke
                # test; report the failure and move on to the next prompt.
                print(f"Model error: {e}")
| |
|
if __name__ == "__main__":
    # Run the smoke test when executed directly (no CLI args, no prompts).
    test_breakthrough_model()