bbkdevops's picture
download
raw
4.8 kB
"""BitSharp trainer: exactness-oriented TinyMind fine-tuning.
Technique: Pure Bit-Margin Sharpening (PBMS)
- train on clean expert rows;
- compare each batch against deterministic corrupted labels;
- enforce clean loss + margin < corrupted loss;
- measure exact next-token accuracy and bit error proxy.
"""
from __future__ import annotations
from datetime import datetime, timezone
import json
import math
from pathlib import Path
import random
import torch
from evaluation.local_evidence import _collate, _encode, _text
from model.architecture import OmegaModel
def _load_rows(paths: list[str | Path]) -> list[dict]:
    """Load every non-blank JSON Lines record from *paths*, in order.

    Args:
        paths: Files to read as UTF-8. Each non-blank line must hold exactly
            one JSON document; whitespace-only lines are skipped.

    Returns:
        The decoded JSON values of all files, concatenated in the order the
        paths were given.

    Raises:
        OSError: If a file cannot be opened or read.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    rows = []
    for path in paths:
        # Iterate the file object instead of calling ``str.splitlines()``:
        # splitlines() also breaks on U+2028, U+2029 and U+0085, which may
        # appear unescaped inside a JSON string and would cut one record in
        # two. File iteration only splits on "\n", "\r" and "\r\n".
        with Path(path).open(encoding="utf-8") as handle:
            rows.extend(json.loads(line) for line in handle if line.strip())
    return rows
def _corrupt_labels(labels: torch.Tensor, vocab_size: int) -> torch.Tensor:
    """Return a new tensor in which every id >= 4 is deterministically remapped.

    Entries below 4 (negative values such as -100 included) are copied
    unchanged. Every entry ``x >= 4`` becomes
    ``4 + ((x - 3) % max(vocab_size - 4, 1))``; for ids in
    ``[4, vocab_size - 1]`` that is the next id up, with the top id wrapping
    back to 4. When ``vocab_size - 4 <= 1`` every remapped entry collapses
    to 4.

    Args:
        labels: Label ids; the input tensor is not modified.
        vocab_size: Vocabulary size used to derive the wrap-around modulus.

    Returns:
        A tensor shaped like *labels* holding the remapped ids.
    """
    modulus = max(vocab_size - 4, 1)
    remappable = labels >= 4
    # Computed for every position; torch.where discards it wherever the
    # original id is below 4.
    shifted = 4 + (labels - 3) % modulus
    return torch.where(remappable, shifted, labels)
@torch.no_grad()
def _exact_metrics(model: OmegaModel, sequences: list[torch.Tensor]) -> dict:
    """Score *model* on *sequences* in one gradient-free forward pass.

    The model is switched to eval mode and is left in that mode on return.
    All sequences are collated into a single batch.

    Args:
        model: Model called as ``model(input_ids, labels=labels)``; its output
            is indexed with ``"logits"`` and ``"loss"``.
        sequences: Encoded sequences handed to ``_collate``.

    Returns:
        A dict with keys ``loss``, ``perplexity`` (``exp`` of the loss with
        the exponent capped at 20), ``next_token_accuracy``,
        ``bit_error_proxy`` (one minus the accuracy) and ``tokens`` (number
        of scored positions).
    """
    model.eval()
    input_ids, labels = _collate(sequences)
    outputs = model(input_ids, labels=labels)

    # Logits at position t are compared against the label at position t + 1.
    predictions = outputs["logits"][..., :-1, :].argmax(dim=-1)
    targets = labels[..., 1:]
    # Positions labelled -100 are excluded from the counts.
    scored = targets != -100

    n_tokens = scored.sum().item()
    n_hits = ((predictions == targets) & scored).sum().item()
    # max(..., 1) avoids dividing by zero when no position is scored.
    accuracy = n_hits / max(n_tokens, 1)

    loss_value = float(outputs["loss"].item())
    return {
        "loss": loss_value,
        "perplexity": float(math.exp(min(loss_value, 20.0))),
        "next_token_accuracy": accuracy,
        "bit_error_proxy": 1.0 - accuracy,
        "tokens": int(n_tokens),
    }
def run_bitsharp_training(
    checkpoint_path: str | Path,
    dataset_paths: list[str | Path],
    out_dir: str | Path,
    steps: int = 128,
    margin: float = 0.15,
    seed: int = 20260523,
) -> dict:
    """Fine-tune a checkpoint with the clean-vs-corrupted margin objective.

    Each step computes the model loss on a small batch with its real labels
    ("clean") and again with labels remapped by ``_corrupt_labels``
    ("corrupt"), then minimises
    ``clean + 0.35 * relu(clean + margin - corrupt)``. Metrics from
    ``_exact_metrics`` are taken on a held-out slice before and after
    training. The updated checkpoint is written to
    ``<out_dir>/bitsharp_purefield.pt`` and a JSON report to
    ``<out_dir>/bitsharp_report.json``.

    Args:
        checkpoint_path: Checkpoint readable by ``torch.load`` that contains
            the keys ``"model_cfg"`` and ``"model_state"``.
        dataset_paths: JSON Lines files holding the training rows.
        out_dir: Output directory; created (with parents) if missing.
        steps: Requested number of optimisation steps. At least one step is
            always run, even if this is zero or negative.
        margin: Amount by which the corrupt loss should exceed the clean loss
            before the margin penalty vanishes.
        seed: Seed applied to ``random`` and ``torch``.

    Returns:
        The report dict that was also serialised to ``bitsharp_report.json``.

    Raises:
        ValueError: If the dataset files contain no rows.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)

    # SECURITY: weights_only=False unpickles arbitrary objects, so this must
    # only ever be pointed at checkpoints from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    model = OmegaModel(ckpt["model_cfg"])
    model.load_state_dict(ckpt["model_state"])
    cfg = model.cfg

    rows = _load_rows(dataset_paths)
    if not rows:
        # Without this guard the run fails later with an unrelated-looking
        # error (the batch index is taken modulo an empty training list).
        raise ValueError(
            f"no training rows found in {[str(path) for path in dataset_paths]}"
        )
    sequences = [_encode(_text(row), cfg.max_seq_len, cfg.vocab_size) for row in rows]

    # First 80% (at least one sequence) trains; the remainder evaluates. If
    # nothing remains, the last sequence is reused for evaluation.
    split = max(1, int(len(sequences) * 0.8))
    train_sequences = sequences[:split]
    eval_sequences = sequences[split:] or sequences[-1:]

    before = _exact_metrics(model, eval_sequences)
    opt = torch.optim.AdamW(model.parameters(), lr=7e-4, weight_decay=0.02)
    losses = []
    margins = []
    # _exact_metrics leaves the model in eval mode; switch back for training.
    model.train()
    batch_size = min(3, len(train_sequences))
    for step in range(max(1, int(steps))):
        # Sliding window over the training set, wrapping around at the end.
        batch = [
            train_sequences[(step + j) % len(train_sequences)]
            for j in range(batch_size)
        ]
        input_ids, labels = _collate(batch)
        clean = model(input_ids, labels=labels)["loss"]
        corrupt_labels = _corrupt_labels(labels, cfg.vocab_size)
        corrupt = model(input_ids, labels=corrupt_labels)["loss"]
        # Penalise only while the corrupt loss fails to beat the clean loss
        # by at least `margin`.
        margin_loss = torch.relu(clean + margin - corrupt)
        loss = clean + 0.35 * margin_loss
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        losses.append(float(loss.item()))
        margins.append(float((corrupt - clean).item()))

    after = _exact_metrics(model, eval_sequences)
    checkpoint_out = out / "bitsharp_purefield.pt"
    torch.save(
        {
            # Keep every field of the input checkpoint, replacing the weights
            # and attaching a summary of this run.
            **ckpt,
            "model_state": model.state_dict(),
            "bitsharp": {"steps": int(steps), "margin": margin, "before": before, "after": after},
        },
        checkpoint_out,
    )
    report = {
        "schema_version": "tinymind-bitsharp-training-v1",
        "created_at": datetime.now(timezone.utc).isoformat(),
        "technique": "Pure Bit-Margin Sharpening",
        "claim_scope": "exactness-oriented fine-tuning with corrupted-label margin, not world-best proof",
        "world_best_claim_allowed": False,
        "input_checkpoint": str(checkpoint_path),
        "checkpoint": str(checkpoint_out),
        "dataset_paths": [str(path) for path in dataset_paths],
        "steps": int(steps),
        "margin": margin,
        "before": before,
        "after": after,
        "loss_tail": losses[-10:],
        "margin_tail": margins[-10:],
        # Improved means a strictly lower eval loss with no drop in accuracy.
        "improved": after["loss"] < before["loss"] and after["next_token_accuracy"] >= before["next_token_accuracy"],
    }
    report_path = out / "bitsharp_report.json"
    report["report_path"] = str(report_path)
    report_path.write_text(json.dumps(report, ensure_ascii=False, indent=2, sort_keys=True), encoding="utf-8")
    return report

Xet Storage Details

Size:
4.8 kB
·
Xet hash:
46a99800109cda58668793cd88b1eda8c32008e57c01a826bc1dde69ea30e3e5

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.