File size: 3,240 Bytes
36c78b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import json
import os
from typing import Dict, List, Optional, Tuple

import torch

from .model import BitTransformerLM
from .training import train_loop


def collapse_submodel(
    cluster_data: List[List[int]],
    target_params: Dict,
    floors: Optional[Dict[str, float]] = None,
    max_rounds: int = 3,
    width_scale: float = 1.5,
    forward_kwargs: Optional[Dict] = None,
) -> Tuple[BitTransformerLM, Dict[str, float]]:
    """Distill a submodel from clustered bit sequences.

    Each round builds a fresh ``BitTransformerLM`` from the current
    parameters, trains it on the first 80% of ``cluster_data`` (at least one
    sequence) and measures telemetry on the remainder. If nothing is left
    over, the training split is reused for validation. When any telemetry
    floor is unmet, the architecture grows before the next round:

    * the first failure adds one layer (deepening);
    * the second failure multiplies ``d_model`` and ``dim_feedforward`` by
      ``width_scale`` (widening, done at most once);
    * every later failure adds another layer.

    Args:
        cluster_data: Bit sequences to distill from. Must be non-empty and
            convertible to a ``torch.long`` tensor.
        target_params: Keyword arguments for ``BitTransformerLM``. A shallow
            copy is grown; the caller's dict is not mutated.
        floors: Minimum acceptable values for ``"negentropy"``,
            ``"lz_complexity"`` and ``"symbiosis_score"``. Omitted keys fall
            back to the defaults (0.5, 0.3 and 0.5 respectively).
        max_rounds: Maximum number of train/evaluate rounds; must be >= 1.
        width_scale: Multiplier applied to the hidden dimensions on the
            widening round.
        forward_kwargs: Extra keyword arguments forwarded to the model's
            forward pass, both inside ``train_loop`` and during validation.

    Returns:
        The model from the last round that ran (the first one meeting every
        floor, or the final attempt) together with its validation metrics.

    Raises:
        ValueError: If ``max_rounds`` is less than 1 or ``cluster_data``
            contains no bits.
    """
    # Without at least one round there is no model to return.
    if max_rounds < 1:
        raise ValueError(f"max_rounds must be >= 1, got {max_rounds}")

    # Merge rather than replace so a partial ``floors`` dict cannot KeyError
    # in the floor check below.
    default_floors = {"negentropy": 0.5, "lz_complexity": 0.3, "symbiosis_score": 0.5}
    floors = {**default_floors, **(floors or {})}

    bit_tensor = torch.tensor(cluster_data, dtype=torch.long)
    if bit_tensor.numel() == 0:
        raise ValueError("cluster_data must contain at least one bit")
    n = len(bit_tensor)
    # Keep at least one sequence for training; hold out the rest (~20%).
    split = max(1, int(0.8 * n))
    train_bits = bit_tensor[:split]
    val_bits = bit_tensor[split:]
    if len(val_bits) == 0:
        # Too few sequences to hold any out: validate on the training data.
        val_bits = train_bits

    params = target_params.copy()
    metrics: Dict[str, float] = {}
    width_scaled = False
    for round_idx in range(max_rounds):
        model = BitTransformerLM(**params)
        train_loop(
            model,
            train_bits,
            epochs=2,
            compress_prob=0.5,
            direct_prob=0.0,
            log=False,
            forward_kwargs=forward_kwargs,
        )
        with torch.no_grad():
            logits, telemetry = model(val_bits, **(forward_kwargs or {}))
            metrics = {
                "negentropy": model.negentropy_logits(logits).mean().item(),
                "lz_complexity": model.lz_complexity_logits(logits).mean().item(),
                "symbiosis_score": telemetry["symbiosis_score"].mean().item(),
            }
        if all(metrics[key] >= floors[key] for key in metrics):
            break
        if round_idx > 0 and not width_scaled:
            # Second failure: widen once instead of deepening again.
            # NOTE(review): the scaled d_model is not rounded to a multiple of
            # the attention head count; if BitTransformerLM requires
            # divisibility this can fail for some inputs -- confirm.
            params["d_model"] = int(params.get("d_model", 32) * width_scale)
            params["dim_feedforward"] = int(
                params.get("dim_feedforward", 64) * width_scale
            )
            width_scaled = True
        else:
            # First failure, and every failure after widening: add a layer.
            params["num_layers"] = max(1, params.get("num_layers", 1)) + 1
    return model, metrics


def save_distilled_model(
    model: BitTransformerLM,
    path: str,
    metrics: Dict[str, float],
    floors: Optional[Dict[str, float]] = None,
) -> None:
    """Persist a distilled model's weights plus a JSON metrics summary.

    The model's ``state_dict()`` is written to ``path`` via ``torch.save``.
    Afterwards a file named ``metrics.json``, located in the directory that
    contains ``path``, is (over)written with an object of the form
    ``{"metrics": <metrics>, "floors": <floors>}``. When ``floors`` is not
    given (or is empty) an empty dict is recorded in its place.

    Because the summary filename is fixed, models saved into the same
    directory share, and overwrite, a single ``metrics.json``.
    """
    # Weights go first; the summary is only written once they are saved.
    torch.save(model.state_dict(), path)

    summary = {
        "metrics": metrics,
        "floors": floors if floors else {},
    }
    target_dir = os.path.dirname(path)
    summary_file = os.path.join(target_dir, "metrics.json")
    with open(summary_file, "w") as handle:
        json.dump(summary, handle)