"""
Better Sampling for BitTransformerLM
"""

import sys

import torch
import torch.nn.functional as F

# Make the local BitTransformerLM checkout importable.
sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text

def load_model():
    """Instantiate BitTransformerLM and load the best checkpoint."""
    model = BitTransformerLM(
        d_model=512, nhead=16, num_layers=8, dim_feedforward=1024,
        max_seq_len=512, reversible=True, use_checkpoint=False,
        use_autocast=False, use_act=True, act_threshold=0.9,
        lambda_K=0.05, lambda_C=0.05, lambda_S=0.05
    )

    checkpoint = torch.load(
        '/data/BitTransformerLM/checkpoints/checkpoint_best.pt',
        map_location='cpu',
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    return model

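# Illustrative sanity check (a sketch, not part of bit_transformer):
# assumes text_to_bits returns a flat list of 0/1 ints (9 bits per
# character: 8 data bits plus a parity bit) and bits_to_text inverts it.
def _codec_round_trip_check(sample="Hi"):
    bits = text_to_bits(sample)
    assert bits_to_text(bits) == sample, "text<->bits round trip failed"
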
def smart_generate(model, prompt, max_chars=5):
    """Generate with better sampling strategies."""
    print(f"\n🎯 Smart generating from: '{prompt}'")

    input_bits = text_to_bits(prompt)
    generated_bits = input_bits.copy()

    with torch.no_grad():
        for char_idx in range(max_chars):
            # Each character is a 9-bit frame: 8 data bits + 1 parity bit.
            char_bits = []

            for bit_idx in range(9):
                # Condition on everything generated so far, truncated to
                # the most recent 300 bits to stay within max_seq_len.
                context = (generated_bits + char_bits)[-300:]
                context_tensor = torch.tensor(context, dtype=torch.long).unsqueeze(0)

                # The model returns next-bit logits plus safety telemetry.
                logits, telemetry = model(context_tensor)
                next_bit_logits = logits[0, -1, :]

                if bit_idx < 8:
                    # Temperature-scaled top-k sampling for the data bits.
                    # With a binary vocabulary, k=2 keeps the whole
                    # distribution, so temperature does the real work.
                    temperature = 0.8
                    next_bit_logits = next_bit_logits / temperature

                    k = 2
                    top_k_logits, top_k_indices = torch.topk(next_bit_logits, k)
                    probs = F.softmax(top_k_logits, dim=-1)
                    selected_idx = torch.multinomial(probs, 1).item()
                    next_bit = top_k_indices[selected_idx].item()
                else:
                    # Force a valid even-parity bit instead of sampling it.
                    data_bits = char_bits[:8]
                    next_bit = sum(data_bits) % 2

                char_bits.append(next_bit)

            generated_bits.extend(char_bits)

            # Decode the new frame for logging; stop early on sentence enders.
            try:
                data_bits = char_bits[:8]
                byte_val = sum(bit * (2 ** (7 - i)) for i, bit in enumerate(data_bits))

                if 32 <= byte_val <= 126:
                    char = chr(byte_val)
                    print(f"  Char {char_idx + 1}: '{char}' (byte={byte_val})")

                    if char in '.!?\n':
                        break
                else:
                    print(f"  Char {char_idx + 1}: Non-printable (byte={byte_val})")

            except Exception as e:
                print(f"  Char {char_idx + 1}: Decode error: {e}")

    # Decode only the newly generated bits, falling back to a manual
    # byte-by-byte decode if the library decoder rejects the stream.
    generated_only = generated_bits[len(input_bits):]
    try:
        final_text = bits_to_text(generated_only)
        print(f"✨ Result: '{prompt}' + '{final_text}'")
        return final_text
    except Exception as e:
        print(f"❌ Final decode failed: {e}")

        manual_result = ""
        for i in range(0, len(generated_only), 9):
            if i + 8 < len(generated_only):
                char_bits = generated_only[i:i + 8]
                byte_val = sum(bit * (2 ** (7 - j)) for j, bit in enumerate(char_bits))
                manual_result += chr(byte_val) if 32 <= byte_val <= 126 else '?'

        print(f"🔧 Manual decode: '{prompt}' + '{manual_result}'")
        return manual_result

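# Minimal sketch of the 9-bit frame convention the loop above relies on:
# 8 MSB-first data bits followed by one even-parity bit. These helpers
# are illustrative, not part of bit_transformer.
def byte_to_parity_bits(byte_val):
    data_bits = [(byte_val >> (7 - i)) & 1 for i in range(8)]
    return data_bits + [sum(data_bits) % 2]


def parity_bits_to_byte(bits9):
    data_bits, parity = bits9[:8], bits9[8]
    if sum(data_bits) % 2 != parity:
        return None  # parity check failed
    return sum(bit << (7 - i) for i, bit in enumerate(data_bits))
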
def main():
    print("🚀 SMART BITTRANSFORMERLM GENERATION")
    print("=" * 40)

    model = load_model()
    print("✅ Model loaded!")

    # Short prompts keep the 300-bit context window roomy.
    prompts = [
        "Hello",
        "Hi",
        "A",
        "The cat",
        "I am",
        "Yes",
        "No"
    ]

    for prompt in prompts:
        smart_generate(model, prompt, max_chars=4)

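# A minimal alternative sampler, assuming the same (batch, seq, 2) logit
# shape used above: since top-k at k=2 over a binary vocabulary already
# keeps the whole distribution, plain temperature sampling is equivalent
# and simpler. smart_generate's top-k branch could call this instead.
def sample_bit(bit_logits, temperature=0.8):
    probs = F.softmax(bit_logits / temperature, dim=-1)
    return torch.multinomial(probs, 1).item()
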
if __name__ == "__main__":
    main()