Buckets:
bbkdevops/unicosys-hypergraph-bucket / tinymind-native-8b-remote-handoff /bundle /evaluation /bitsharp_training.py
| """BitSharp trainer: exactness-oriented TinyMind fine-tuning. | |
| Technique: Pure Bit-Margin Sharpening (PBMS) | |
| - train on clean expert rows; | |
| - compare each batch against deterministic corrupted labels; | |
| - enforce clean loss + margin < corrupted loss; | |
| - measure exact next-token accuracy and bit error proxy. | |
| """ | |
| from __future__ import annotations | |
| from datetime import datetime, timezone | |
| import json | |
| import math | |
| from pathlib import Path | |
| import random | |
| import torch | |
| from evaluation.local_evidence import _collate, _encode, _text | |
| from model.architecture import OmegaModel | |
| def _load_rows(paths: list[str | Path]) -> list[dict]: | |
| rows = [] | |
| for path in paths: | |
| p = Path(path) | |
| rows.extend(json.loads(line) for line in p.read_text(encoding="utf-8").splitlines() if line.strip()) | |
| return rows | |
def _corrupt_labels(labels: torch.Tensor, vocab_size: int) -> torch.Tensor:
    """Return a deterministically corrupted copy of ``labels``; the input is not modified.

    Ids below 4 (which includes the -100 ignore index) are passed through
    untouched. Every id ``t >= 4`` is rotated within the range starting at 4:
    ``t -> 4 + ((t - 3) mod max(vocab_size - 4, 1))``. When
    ``vocab_size - 4 > 1`` this sends each id in ``[4, vocab_size)`` to a
    different id in that range (``t -> t + 1``, wrapping the last back to 4).
    """
    span = max(vocab_size - 4, 1)
    # Computed for every position; values at ids < 4 are discarded by the select below.
    rotated = 4 + torch.remainder(labels - 3, span)
    passthrough = labels < 4
    return torch.where(passthrough, labels, rotated)
def _exact_metrics(model: OmegaModel, sequences: list[torch.Tensor]) -> dict:
    """Measure teacher-forced next-token exactness of ``model`` on ``sequences``.

    Switches ``model`` to eval mode (it is left in eval mode) and runs a single
    forward pass over the whole collated batch with autograd disabled.

    Returns a dict with:
        ``loss``: the ``"loss"`` value returned by the model for the batch;
        ``perplexity``: ``exp(loss)`` with the exponent capped at 20.0;
        ``next_token_accuracy``: fraction of scored target positions (label
            != -100) where the argmax prediction equals the target;
        ``bit_error_proxy``: ``1 - next_token_accuracy``;
        ``tokens``: number of scored target positions.
    """
    model.eval()
    input_ids, labels = _collate(sequences)
    # Pure measurement: without no_grad the forward pass would build (and keep
    # alive until return) an autograd graph for the entire eval batch.
    with torch.no_grad():
        out = model(input_ids, labels=labels)
    loss = float(out["loss"].item())
    # Logits at position t are scored against the label at t + 1. Assumes
    # labels from _collate are aligned with input positions (not pre-shifted)
    # -- TODO confirm against _collate / OmegaModel.
    logits = out["logits"][..., :-1, :]
    target = labels[..., 1:]
    mask = target != -100
    pred = logits.argmax(dim=-1)
    exact = int(((pred == target) & mask).sum().item())
    total = int(mask.sum().item())
    token_acc = exact / max(total, 1)
    return {
        "loss": loss,
        # Cap avoids OverflowError from math.exp on a very large loss.
        "perplexity": float(math.exp(min(loss, 20.0))),
        "next_token_accuracy": token_acc,
        "bit_error_proxy": 1.0 - token_acc,
        "tokens": total,
    }
def run_bitsharp_training(
    checkpoint_path: str | Path,
    dataset_paths: list[str | Path],
    out_dir: str | Path,
    steps: int = 128,
    margin: float = 0.15,
    seed: int = 20260523,
    *,
    batch_size: int = 3,
    lr: float = 7e-4,
    weight_decay: float = 0.02,
    margin_weight: float = 0.35,
) -> dict:
    """Fine-tune a TinyMind checkpoint with Pure Bit-Margin Sharpening (PBMS).

    Each step computes the model loss on a batch with its clean labels and again
    with ``_corrupt_labels`` applied, then minimises
    ``clean + margin_weight * relu(clean + margin - corrupt)``, pushing the clean
    loss at least ``margin`` below the corrupted loss. Exact next-token metrics
    are measured on a held-out split before and after training.

    Args:
        checkpoint_path: file loadable by ``torch.load`` holding a dict with at
            least ``"model_cfg"`` and ``"model_state"``.
        dataset_paths: JSONL files; every row is encoded and used.
        out_dir: created if missing; receives ``bitsharp_purefield.pt`` and
            ``bitsharp_report.json``.
        steps: number of optimiser steps; values below 1 are clamped to 1 and
            the clamped count is what gets recorded.
        margin: required gap between corrupted and clean loss.
        seed: seeds both ``random`` and ``torch``.
        batch_size: sequences per step (capped at the training-split size).
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        margin_weight: weight of the hinge term in the total loss.

    Returns:
        The report dict; it is also written to ``bitsharp_report.json``.

    Raises:
        ValueError: if ``batch_size < 1`` or ``dataset_paths`` yield no rows.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    # The loop always runs at least one step; record that real count, not the raw argument.
    n_steps = max(1, int(steps))
    random.seed(seed)
    torch.manual_seed(seed)
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)

    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # i.e. run code embedded in the file. Presumably required because
    # "model_cfg" is a custom object (TODO confirm). Only load trusted checkpoints.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    model = OmegaModel(ckpt["model_cfg"])
    model.load_state_dict(ckpt["model_state"])
    cfg = model.cfg

    rows = _load_rows(dataset_paths)
    if not rows:
        raise ValueError(f"no rows found in dataset_paths={[str(p) for p in dataset_paths]}")
    sequences = [_encode(_text(row), cfg.max_seq_len, cfg.vocab_size) for row in rows]
    # First 80% (at least one sequence) trains, the rest evaluates. With a
    # single row the eval split would be empty, so it falls back to the last
    # sequence, which then also appears in the training split.
    split = max(1, int(len(sequences) * 0.8))
    train_sequences = sequences[:split]
    eval_sequences = sequences[split:] or sequences[-1:]

    before = _exact_metrics(model, eval_sequences)

    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    n_train = len(train_sequences)
    per_step = min(batch_size, n_train)
    losses: list[float] = []
    margins: list[float] = []
    model.train()
    for step in range(n_steps):
        # NOTE(review): the window advances one sequence per step, so
        # consecutive batches overlap; kept as-is to preserve training dynamics.
        batch = [train_sequences[(step + j) % n_train] for j in range(per_step)]
        input_ids, labels = _collate(batch)
        clean = model(input_ids, labels=labels)["loss"]
        corrupt = model(input_ids, labels=_corrupt_labels(labels, cfg.vocab_size))["loss"]
        # Hinge term is zero once corrupt >= clean + margin.
        margin_loss = torch.relu(clean + margin - corrupt)
        loss = clean + margin_weight * margin_loss
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        losses.append(float(loss.item()))
        margins.append(float((corrupt - clean).item()))

    after = _exact_metrics(model, eval_sequences)

    hyperparameters = {
        "batch_size": batch_size,
        "lr": lr,
        "weight_decay": weight_decay,
        "margin_weight": margin_weight,
        "seed": seed,
    }
    checkpoint_out = out / "bitsharp_purefield.pt"
    torch.save(
        {
            **ckpt,
            "model_state": model.state_dict(),
            "bitsharp": {
                "steps": n_steps,
                "margin": margin,
                "before": before,
                "after": after,
                "hyperparameters": hyperparameters,
            },
        },
        checkpoint_out,
    )

    report = {
        "schema_version": "tinymind-bitsharp-training-v1",
        "created_at": datetime.now(timezone.utc).isoformat(),
        "technique": "Pure Bit-Margin Sharpening",
        "claim_scope": "exactness-oriented fine-tuning with corrupted-label margin, not world-best proof",
        "world_best_claim_allowed": False,
        "input_checkpoint": str(checkpoint_path),
        "checkpoint": str(checkpoint_out),
        "dataset_paths": [str(path) for path in dataset_paths],
        "steps": n_steps,
        "margin": margin,
        "hyperparameters": hyperparameters,
        "before": before,
        "after": after,
        "loss_tail": losses[-10:],
        "margin_tail": margins[-10:],
        "improved": after["loss"] < before["loss"] and after["next_token_accuracy"] >= before["next_token_accuracy"],
    }
    report_path = out / "bitsharp_report.json"
    report["report_path"] = str(report_path)
    report_path.write_text(json.dumps(report, ensure_ascii=False, indent=2, sort_keys=True), encoding="utf-8")
    return report
Xet Storage Details
- Size:
- 4.8 kB
- Xet hash:
- 46a99800109cda58668793cd88b1eda8c32008e57c01a826bc1dde69ea30e3e5
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.