| import torch | |
| import torch.nn.functional as F | |
| from datasets import load_dataset | |
| from bit_transformer import text_to_bits, collapse_submodel | |
| from progressive_scaleup import progressive_scale_up_text | |
def lines_to_bits(lines, max_len=64):
    """Encode every text line as a bit list of exactly ``max_len`` entries.

    Each line is passed through ``text_to_bits``; the result is cut down to
    at most ``max_len`` bits, and any shortfall is filled with trailing 0s.

    Args:
        lines: iterable of strings to encode.
        max_len: target length of every returned bit list.

    Returns:
        A list with one fixed-length bit list per input line, in input order.
    """
    def _fixed_width(text):
        # Slicing yields a fresh list, so padding it in place is safe.
        bits = text_to_bits(text)[:max_len]
        missing = max_len - len(bits)
        if missing > 0:
            bits.extend([0] * missing)
        return bits

    return [_fixed_width(text) for text in lines]
def main():
    """Run a small WikiText-2 progressive-scale-up and collapse demo.

    Steps:
      1. Load 1% slices of the WikiText-2 (raw) train and validation splits.
      2. Drop blank rows, then keep up to 256 train and 64 validation lines.
      3. Encode each line as a fixed-length bit sequence.
      4. Call ``progressive_scale_up_text`` with fixed hyperparameters.
      5. Build a small model with ``collapse_submodel`` from the first 64
         train sequences.
      6. Print that model's next-bit cross-entropy on the validation bits.

    Raises:
        RuntimeError: if either split slice contains no non-blank lines.
    """
    seq_len = 64  # single source of truth for padding and model length

    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")
    val_ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation[:1%]")

    # WikiText raw rows include empty lines. They would encode to all-zero
    # padding rows that carry no signal and dilute both training data and
    # the validation loss, so skip them.
    train_lines = [item["text"] for item in ds if item["text"].strip()][:256]
    valid_lines = [item["text"] for item in val_ds if item["text"].strip()][:64]
    if not train_lines or not valid_lines:
        raise RuntimeError(
            "No non-empty WikiText lines found in the selected dataset slices"
        )

    train_bits = lines_to_bits(train_lines, max_len=seq_len)
    valid_bits = lines_to_bits(valid_lines, max_len=seq_len)

    # NOTE(review): any return value of progressive_scale_up_text is
    # discarded here, and it is not handed train_bits directly. Confirm
    # whether its output was meant to feed the collapse step below.
    progressive_scale_up_text(
        eps=0.65,
        steps=4,
        width_mult=2.0,
        max_len=seq_len,
        dataset_size=min(64, len(train_bits)),
    )

    target_params = dict(
        d_model=16,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=seq_len,
    )
    model, _ = collapse_submodel(train_bits[:64], target_params, max_rounds=1)

    val_tensor = torch.tensor(valid_bits, dtype=torch.long)

    # Evaluate in inference mode. eval() disables dropout so the loss is
    # deterministic, and no_grad() avoids building an autograd graph.
    # Assumes collapse_submodel returns a torch.nn.Module.
    model.eval()
    with torch.no_grad():
        logits, _ = model(val_tensor)
        # Next-bit prediction: the logits at position t score the bit at t+1.
        pred = logits[:, :-1, :].reshape(-1, 2)
        target = val_tensor[:, 1:].reshape(-1)
        loss = F.cross_entropy(pred, target)
    print("Collapsed model validation loss:", loss.item())


if __name__ == "__main__":
    main()