Remove nested directory: BitTransformerLM/bit_transformer/telemetry.py
Browse files
BitTransformerLM/bit_transformer/telemetry.py
DELETED
|
@@ -1,95 +0,0 @@
|
|
| 1 |
-
import numpy as np
from typing import Dict, List, Optional, TYPE_CHECKING

import torch
from sklearn.cluster import KMeans

if TYPE_CHECKING:  # pragma: no cover
    from .model import BitTransformerLM
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
class TelemetrySynthesizer:
    """Analyze telemetry batches and cluster activation patterns.

    Each telemetry dict is expected to carry per-layer ``activations`` and
    ``attention_maps`` tensor lists — presumably produced by
    ``BitTransformerLM``'s forward pass; TODO confirm against the model.

    Parameters
    ----------
    n_clusters: int
        Number of KMeans clusters to form.
    random_state: Optional[int]
        Seed forwarded to KMeans so clustering is reproducible.  ``None``
        (the default) preserves the original unseeded behaviour.
    """

    def __init__(self, n_clusters: int = 2, random_state: Optional[int] = None) -> None:
        self.n_clusters = n_clusters
        self.random_state = random_state

    def _summary(self, telemetry: Dict[str, List[torch.Tensor]]) -> np.ndarray:
        """Compute activation/attention summaries for one telemetry dict.

        For every (activation, attention) layer pair this records the
        activation mean, its biased variance, and the mean entropy of the
        softmaxed attention rows, returning them as one flat feature vector
        of length ``3 * num_layers``.
        """
        acts = telemetry["activations"]
        attn = telemetry["attention_maps"]
        summaries = []
        for a, m in zip(acts, attn):
            mean = a.mean().item()
            var = a.var(unbiased=False).item()
            prob = m.softmax(-1)
            # clamp_min avoids log(0) for near-one-hot attention rows
            entropy = -(prob * prob.clamp_min(1e-9).log()).sum(-1).mean().item()
            summaries.append([mean, var, entropy])
        return np.array(summaries).ravel()

    def synthesize(
        self, telemetries: List[Dict[str, List[torch.Tensor]]], bit_seqs: torch.Tensor
    ) -> Dict[str, List]:
        """Cluster telemetry summaries and return cluster info.

        Returns a dict with ``cluster_assignments`` (one KMeans label per
        telemetry entry) and ``representatives`` (the first sequence of each
        cluster as a list of ints, or ``[]`` for an empty cluster).
        """
        data = np.stack([self._summary(t) for t in telemetries])
        # random_state makes clustering reproducible when a seed was given;
        # the default (None) matches the previous unseeded behaviour.
        km = KMeans(
            n_clusters=self.n_clusters,
            n_init=1,
            random_state=self.random_state,
        )
        labels = km.fit_predict(data)
        representatives: List[List[int]] = []
        for c in range(self.n_clusters):
            idx = np.where(labels == c)[0]
            if len(idx) > 0:
                representatives.append(bit_seqs[idx[0]].tolist())
            else:
                # Empty cluster: no representative sequence available.
                representatives.append([])
        return {"cluster_assignments": labels.tolist(), "representatives": representatives}

    def cluster_sequences(
        self, model: "BitTransformerLM", bit_seqs: torch.Tensor
    ) -> List[List[int]]:
        """Run the model to gather telemetry and return representative sequences.

        Parameters
        ----------
        model: BitTransformerLM
            Model used to compute telemetry for each sequence.
        bit_seqs: torch.Tensor
            Tensor containing one bit sequence per row.

        Returns
        -------
        list[list[int]]
            Representative sequences chosen from KMeans clusters.
        """
        telemetries: List[Dict[str, List[torch.Tensor]]] = []
        with torch.no_grad():  # inference only; no gradients needed
            for seq in bit_seqs:
                # assumes model(batch) returns (logits, telemetry) — per call site here
                _, tele = model(seq.unsqueeze(0))
                telemetries.append(tele)
        info = self.synthesize(telemetries, bit_seqs)
        return info["representatives"]
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
def detect_metric_drift(
    metrics_log: Dict[str, List[float]],
    window: int = 10,
    threshold: float = 0.2,
) -> Dict[str, bool]:
    """Detect metric drift between consecutive windows.

    Args:
        metrics_log: History of scalar metrics keyed by name.
        window: Number of recent steps to compare.
        threshold: Absolute difference required to flag drift.

    Returns:
        Dictionary mapping metric keys to a boolean drift indicator.
    """
    drift: Dict[str, bool] = {}
    for key, values in metrics_log.items():
        # Need two full windows of history before drift can be assessed.
        if len(values) < window * 2:
            drift[key] = False
            continue
        recent = np.mean(values[-window:])
        prev = np.mean(values[-2 * window : -window])
        # Cast to a plain bool: the numpy comparison yields np.bool_, which
        # is not a bool subclass and breaks isinstance checks / JSON dumps,
        # contradicting the declared Dict[str, bool] return type.
        drift[key] = bool(abs(recent - prev) > threshold)
    return drift
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|