Remove nested directory: BitTransformerLM/bit_transformer/collapse.py
Browse files
BitTransformerLM/bit_transformer/collapse.py
DELETED
|
@@ -1,95 +0,0 @@
|
|
| 1 |
-
import json
|
| 2 |
-
import os
|
| 3 |
-
from typing import Dict, List, Optional, Tuple
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
|
| 7 |
-
from .model import BitTransformerLM
|
| 8 |
-
from .training import train_loop
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def collapse_submodel(
    cluster_data: List[List[int]],
    target_params: Dict,
    floors: Optional[Dict[str, float]] = None,
    max_rounds: int = 3,
    width_scale: float = 1.5,
    forward_kwargs: Optional[Dict] = None,
) -> Tuple[BitTransformerLM, Dict[str, float]]:
    """Distill a submodel from clustered bit sequences.

    The routine trains a fresh ``BitTransformerLM`` on the cluster data and
    checks its telemetry against ``floors``. When the floors are unmet it
    escalates capacity between rounds: deepen after the first failure, widen
    the hidden dimensions by ``width_scale`` once after the second, then keep
    deepening.

    Args:
        cluster_data: Bit sequences (lists of 0/1 ints); split 80/20 into
            train/validation tensors.
        target_params: Keyword arguments for ``BitTransformerLM``. Copied, so
            the caller's dict is never mutated.
        floors: Minimum acceptable telemetry values; defaults to
            ``{"negentropy": 0.5, "lz_complexity": 0.3, "symbiosis_score": 0.5}``.
        max_rounds: Maximum train/evaluate attempts. Must be at least 1.
        width_scale: Multiplier applied once to ``d_model`` and
            ``dim_feedforward`` when widening.
        forward_kwargs: Extra keyword arguments forwarded to the model call
            during both training and validation.

    Returns:
        The distilled model and its final telemetry metrics.

    Raises:
        ValueError: If ``max_rounds`` is less than 1 (previously this
            surfaced as an opaque ``UnboundLocalError`` at the return).
    """
    if max_rounds < 1:
        raise ValueError("max_rounds must be at least 1")
    if floors is None:
        floors = {"negentropy": 0.5, "lz_complexity": 0.3, "symbiosis_score": 0.5}

    bit_tensor = torch.tensor(cluster_data, dtype=torch.long)
    n = len(bit_tensor)
    split = max(1, int(0.8 * n))
    train_bits = bit_tensor[:split]
    val_bits = bit_tensor[split:]
    if len(val_bits) == 0:
        # Tiny clusters: validate on the training split rather than nothing.
        val_bits = train_bits

    params = target_params.copy()
    metrics: Dict[str, float] = {}
    width_scaled = False
    for round_idx in range(max_rounds):
        model = BitTransformerLM(**params)
        train_loop(
            model,
            train_bits,
            epochs=2,
            compress_prob=0.5,
            direct_prob=0.0,
            log=False,
            forward_kwargs=forward_kwargs,
        )
        # Evaluate in inference mode so stochastic layers (e.g. dropout)
        # cannot skew the telemetry that the floor check depends on.
        model.eval()
        with torch.no_grad():
            logits, telemetry = model(val_bits, **(forward_kwargs or {}))
            neg_k = model.negentropy_logits(logits).mean().item()
            lz_c = model.lz_complexity_logits(logits).mean().item()
            sym_s = telemetry["symbiosis_score"].mean().item()
        metrics = {
            "negentropy": neg_k,
            "lz_complexity": lz_c,
            "symbiosis_score": sym_s,
        }
        if (
            neg_k >= floors["negentropy"]
            and lz_c >= floors["lz_complexity"]
            and sym_s >= floors["symbiosis_score"]
        ):
            break
        # Escalation schedule: deepen first, widen exactly once, then keep
        # deepening on subsequent failed rounds.
        if round_idx == 0:
            params["num_layers"] = max(1, params.get("num_layers", 1)) + 1
        elif not width_scaled:
            params["d_model"] = int(params.get("d_model", 32) * width_scale)
            params["dim_feedforward"] = int(
                params.get("dim_feedforward", 64) * width_scale
            )
            width_scaled = True
        else:
            params["num_layers"] = max(1, params.get("num_layers", 1)) + 1
    return model, metrics
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
def save_distilled_model(
    model: BitTransformerLM,
    path: str,
    metrics: Dict[str, float],
    floors: Optional[Dict[str, float]] = None,
) -> None:
    """Persist a distilled model together with a metric summary.

    The model's ``state_dict`` is written to ``path``, and a
    ``metrics.json`` file is created in the same directory recording the
    achieved metrics alongside the target floors (an empty dict when no
    floors were supplied).
    """
    summary = {"metrics": metrics, "floors": floors or {}}
    summary_file = os.path.join(os.path.dirname(path), "metrics.json")
    torch.save(model.state_dict(), path)
    with open(summary_file, "w") as handle:
        json.dump(summary, handle)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|