bbkdevops's picture
download
raw
4.79 kB
"""TinyMind PureMath objective and measurement helpers."""
from __future__ import annotations
from dataclasses import asdict, dataclass
import csv
import json
import math
from pathlib import Path
from typing import Mapping
import torch.nn as nn
from model.config import OmegaConfig
# Measurement categories that must all be covered before a claim is allowed;
# the objective report and the evidence check both use this as their default set.
REQUIRED_MEASUREMENT_CATEGORIES = ("quality", "size", "context", "stability", "speed", "quantization")
@dataclass(frozen=True)
class PureMathWeights:
    """Immutable lambda coefficients for the PureMath objective.

    The LM loss enters the objective unweighted; each field below multiplies
    the objective term named by its suffix. Being frozen, a single default
    instance can safely be shared as a default argument.
    """

    # Penalties derived from [0, 1] evaluation scores.
    lambda_reason: float = 1.0
    lambda_factual: float = 1.0
    lambda_consistency: float = 1.0
    # Efficiency and deployment penalties.
    lambda_energy: float = 0.2
    lambda_quant_drift: float = 1.0
    lambda_params: float = 0.05
    lambda_runtime: float = 0.2
def estimate_model_bits(model: nn.Module, cfg: OmegaConfig) -> int:
    """Estimate how many bits ``model``'s parameters occupy.

    Each parameter element costs 4 bits when ``cfg.precision_mode`` is
    ``"int4_sparse_fast"`` or ``cfg.sparsity_mode`` is anything other than
    ``"dense"``, and 16 bits otherwise. Only ``model.parameters()`` is
    counted; buffers are not included.

    NOTE(review): any non-dense sparsity mode is costed at 4 bits even if the
    precision mode is not int4 -- confirm that this is intended.
    """
    low_bit = cfg.precision_mode == "int4_sparse_fast" or cfg.sparsity_mode != "dense"
    width = 4 if low_bit else 16
    element_count = sum(param.numel() for param in model.parameters())
    return int(element_count * width)
def runtime_cost_proxy(metrics: Mapping[str, float]) -> float:
    """Approximate runtime cost as prefill time for the context plus one decode step.

    Computes ``context_tokens / prefill_tokens_per_sec + 1 / decode_tokens_per_sec``.
    That is roughly seconds, assuming the keys' "per_sec" units hold.
    ``context_tokens`` defaults to 0.0. Both throughputs default to 1.0 and are
    floored at 1.0, so a missing or zero rate can never divide by zero.
    """

    def throughput(key: str) -> float:
        return max(float(metrics.get(key, 1.0)), 1.0)

    prefill_time = float(metrics.get("context_tokens", 0.0)) / throughput("prefill_tokens_per_sec")
    decode_time = 1.0 / throughput("decode_tokens_per_sec")
    return prefill_time + decode_time
def _score_loss(metrics: Mapping[str, float], key: str) -> float:
    """Convert the score stored under ``key`` into a loss ``1 - clip(score, 0, 1)``.

    A missing key reads as a score of 0.0, i.e. the maximum loss of 1.0.
    """
    raw_score = float(metrics.get(key, 0.0))
    bounded = max(0.0, min(1.0, raw_score))
    return 1.0 - bounded
def _quality_index(metrics: Mapping[str, float]) -> float:
    """Average four [0, 1] quality signals, floored at 1e-6.

    The signals are ``exp(-max(lm_loss, 0))`` and the reason, factual and
    consistency scores, each clipped to [0, 1].

    - A missing score counts as 0.0.
    - A missing ``lm_loss`` counts as 0.0, which gives an LM quality of 1.0.
    - The 1e-6 floor keeps the result safe to use as a divisor.
    """
    lm_loss = max(float(metrics.get("lm_loss", 0.0)), 0.0)
    signals = [math.exp(-lm_loss)]
    signals.extend(
        float(metrics.get(name, 0.0))
        for name in ("reason_score", "factual_score", "consistency_score")
    )
    clipped = [max(0.0, min(1.0, value)) for value in signals]
    mean = sum(clipped) / len(clipped)
    return max(mean, 1e-6)
def compute_puremath_objective(
    model: nn.Module,
    cfg: OmegaConfig,
    metrics: Mapping[str, float],
    weights: PureMathWeights = PureMathWeights(),
    evidence_path: str | Path | None = None,
) -> dict:
    """Evaluate the TinyMind PureMath objective and return it as a report dict.

    The scalar objective is::

        L_lm + λr·L_reason + λf·L_factual + λc·L_consistency
             + λe·L_energy + λq·L_quant_drift + λp·Params_bits + λt·Runtime_cost

    The terms are built as follows:

    - The three score-based losses are ``1 - clip(score, 0, 1)``.
    - Activation energy and quantization drift are floored at zero.
    - ``Params_bits`` is ``estimate_model_bits`` divided by 1e9.
    - ``Runtime_cost`` is ``runtime_cost_proxy(metrics)``.
    - ``lm_loss`` is used as-is; a missing value is 0.0.

    Args:
        model: Module whose parameters are counted for the size term.
        cfg: Config whose precision/sparsity mode sets the bits per parameter.
        metrics: Measured values; missing keys fall back to the helpers' defaults.
        weights: Lambda coefficients for every non-LM term.
        evidence_path: Optional evidence file checked by ``world_best_claim_allowed``.

    Returns:
        A dict containing:

        - the objective name and the equation text;
        - the weights as a dict, and each term;
        - ``total_loss``, the raw ``params_bits`` and ``runtime_cost``;
        - ``bits_per_quality``;
        - the required categories and their measurement status;
        - ``claim_allowed``.
    """
    size_bits = estimate_model_bits(model, cfg)
    runtime = runtime_cost_proxy(metrics)

    def floored(key: str) -> float:
        return max(float(metrics.get(key, 0.0)), 0.0)

    terms = {
        "lm": float(metrics.get("lm_loss", 0.0)),
        "reason": _score_loss(metrics, "reason_score"),
        "factual": _score_loss(metrics, "factual_score"),
        "consistency": _score_loss(metrics, "consistency_score"),
        "energy": floored("activation_energy"),
        "quant_drift": floored("quant_drift"),
        "params_bits": size_bits / 1_000_000_000.0,
        "runtime_cost": runtime,
    }

    # Accumulate strictly left to right, in equation order. A plain loop (not
    # sum(), which uses compensated float summation on 3.12+) keeps the
    # rounding identical to chained "+".
    weighted_terms = (
        (weights.lambda_reason, "reason"),
        (weights.lambda_factual, "factual"),
        (weights.lambda_consistency, "consistency"),
        (weights.lambda_energy, "energy"),
        (weights.lambda_quant_drift, "quant_drift"),
        (weights.lambda_params, "params_bits"),
        (weights.lambda_runtime, "runtime_cost"),
    )
    total = terms["lm"]
    for coefficient, name in weighted_terms:
        total += coefficient * terms[name]

    categories = list(REQUIRED_MEASUREMENT_CATEGORIES)
    return {
        "objective": "TinyMind PureMath Core",
        "equation": "L_lm + λrL_reason + λfL_factual + λcL_consistency + λeL_energy + λqL_quant_drift + λpParams_bits + λtRuntime_cost",
        "weights": asdict(weights),
        "terms": terms,
        "total_loss": float(total),
        "params_bits": size_bits,
        "runtime_cost": runtime,
        # _quality_index is floored at 1e-6, so this division is always defined.
        "bits_per_quality": float(size_bits / _quality_index(metrics)),
        "required_categories": categories,
        # NOTE(review): every category is reported as measured (True)
        # unconditionally, whichever metrics were supplied. If this report is
        # saved as JSON and passed back as evidence, it satisfies
        # world_best_claim_allowed by construction. Confirm this is intended.
        "measurement_status": dict.fromkeys(categories, True),
        "claim_allowed": evidence_path is not None
        and world_best_claim_allowed(evidence_path, categories),
    }
def world_best_claim_allowed(
    evidence_path: str | Path | None,
    required_categories: list[str] | tuple[str, ...] = REQUIRED_MEASUREMENT_CATEGORIES,
) -> bool:
    """Return True only if an evidence file covers every required category.

    The evidence format is chosen by the file suffix (case-insensitive):

    * ``.csv`` -- each required category must appear in the ``category``
      column of at least one row. No other column is inspected.
    * ``.json`` -- the top-level value must be an object that meets both rules:

      - its ``required_categories`` contains every required category;
      - its ``measurement_status`` maps each of them to the JSON boolean
        ``true``. Truthy values such as ``1`` or ``"yes"`` do not count.

    Args:
        evidence_path: Path to the evidence file, or ``None`` for no evidence.
        required_categories: Categories that must all be covered.

    Returns:
        ``True`` when the evidence covers every required category. The following
        all yield ``False`` rather than raising, so the result can gate a claim
        directly:

        - ``None``, or a path that is not a regular file;
        - an unsupported suffix;
        - unreadable, non-UTF-8, or structurally malformed content.
    """
    if evidence_path is None:
        return False
    path = Path(evidence_path)
    suffix = path.suffix.lower()
    # is_file() rather than exists(): a *directory* named e.g. "evidence.json"
    # would otherwise pass and make the read below raise IsADirectoryError.
    if suffix not in {".json", ".csv"} or not path.is_file():
        return False

    if suffix == ".csv":
        try:
            with path.open("r", encoding="utf-8", newline="") as f:
                seen = {row.get("category") for row in csv.DictReader(f)}
        except (OSError, UnicodeDecodeError, csv.Error):
            # Unreadable or malformed evidence counts as no evidence.
            return False
        return all(category in seen for category in required_categories)

    try:
        report = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, UnicodeDecodeError, json.JSONDecodeError):
        # Unreadable, non-UTF-8, or invalid JSON counts as no evidence.
        return False
    # json.loads accepts any JSON value; only an object can carry both fields.
    if not isinstance(report, dict):
        return False
    status = report.get("measurement_status", {})
    if not isinstance(status, dict):
        return False
    try:
        declared = set(report.get("required_categories", []))
    except TypeError:
        # e.g. a number or null, or a list holding unhashable entries.
        return False
    return all(
        status.get(category) is True and category in declared
        for category in required_categories
    )

Xet Storage Details

Size:
4.79 kB
·
Xet hash:
7a5428b5dec671cf96610d7285be4d1f6a4b63f74288f0cbfc722139e31fb78d

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.