Spaces:
Paused
Paused
File size: 7,826 Bytes
751ad26 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 | from __future__ import annotations
from dataclasses import asdict, dataclass, field
from typing import Any, Literal, Mapping, get_args
# Allowed values for EvaluationMetadata.split (checked at runtime in __post_init__).
EvaluationSplit = Literal["calibration", "held_out"]
# Allowed values for EvaluationMetadata.lane. The names parallel the systems /
# quality / diagnostic metric groups collected by derive_standard_metrics,
# although that function does not itself filter metrics by lane.
EvaluationLane = Literal["systems", "quality", "diagnostic"]
# Allowed values for EvaluationMetadata.prompt_family.
PromptFamily = Literal["synthetic_exact_length", "held_out_natural_text", "standardized_long_context"]
# Allowed values for EvaluationMetadata.truth_type; presumably names what the
# harness compares against (reference trace vs. paged runtime) — TODO confirm.
HarnessTruthType = Literal["reference_trace", "paged_runtime"]
# Page-mode suffixes; page_format_histogram_from_result looks up result keys of
# the form "k_<mode>_pages" and "v_<mode>_pages" for each entry, in this order.
_PAGE_MODE_FIELDS: tuple[str, ...] = ("m0", "m1", "m2", "m3", "m4", "t3")
def _coerce_non_negative_int(value: Any, *, field_name: str) -> int:
parsed = int(value)
if parsed < 0:
raise ValueError(f"{field_name} must be non-negative")
return parsed
def _coerce_positive_int(value: Any, *, field_name: str) -> int:
parsed = int(value)
if parsed <= 0:
raise ValueError(f"{field_name} must be positive")
return parsed
def page_format_histogram_from_result(result: Mapping[str, Any]) -> dict[str, int]:
    """Build a page-count histogram from ``<kind>_<mode>_pages`` result keys.

    For each kind in ``("k", "v")`` and each mode in ``_PAGE_MODE_FIELDS``,
    a present key ``"k_m0_pages"`` becomes the histogram entry ``"K:M0"``.
    Absent keys are skipped. Falsy values (e.g. ``None``) count as 0, and
    negative counts are clamped to 0. Values that ``int()`` rejects raise.

    Returns:
        A dict in kind-then-mode order; empty when no key is present.
    """
    counts: dict[str, int] = {}
    for prefix in ("k", "v"):
        for mode in _PAGE_MODE_FIELDS:
            source_key = f"{prefix}_{mode}_pages"
            if source_key in result:
                raw_count = result[source_key] or 0
                counts[f"{prefix.upper()}:{mode.upper()}"] = max(int(raw_count), 0)
    return counts
def _pick_context_length(result: Mapping[str, Any]) -> int | None:
for key in ("context_length_effective", "context_length", "prompt_length", "sequence_length", "prefix_length"):
value = result.get(key)
if value is None:
continue
try:
parsed = int(value)
except (TypeError, ValueError):
continue
if parsed > 0:
return parsed
return None
def _pick_float(result: Mapping[str, Any], *keys: str) -> float | None:
for key in keys:
if key not in result:
continue
value = result.get(key)
if value is None:
continue
try:
return float(value)
except (TypeError, ValueError):
continue
return None
@dataclass(slots=True)
class EvaluationMetadata:
    """Validated description of a single evaluation run.

    ``__post_init__`` does three things:

    - checks that the required string fields are non-empty;
    - checks each Literal-typed field against the values declared on its
      type alias (``EvaluationSplit``, ``EvaluationLane``, ``PromptFamily``,
      ``HarnessTruthType``);
    - coerces the integer fields in place.
    """

    model_id: str
    model_family: str
    backend: str
    device: str
    torch_dtype: str
    split: EvaluationSplit
    lane: EvaluationLane
    prompt_family: PromptFamily
    dataset_name: str
    prompt_count: int
    batch_size: int
    truth_type: HarnessTruthType
    effective_budget_rule: str
    context_length: int | None = None
    decode_steps: int | None = None
    eval_steps: int | None = None
    notes: list[str] = field(default_factory=list)

    @staticmethod
    def _require_choice(value: Any, choices: tuple[Any, ...], field_name: str) -> None:
        """Raise ValueError unless *value* is one of *choices*.

        The message lists the allowed values, e.g. ``"a or b"`` or
        ``"a, b, or c"``.
        """
        if value in choices:
            return
        names = [str(choice) for choice in choices]
        if len(names) <= 2:
            allowed = " or ".join(names)
        else:
            allowed = ", ".join(names[:-1]) + f", or {names[-1]}"
        raise ValueError(f"{field_name} must be {allowed}")

    def __post_init__(self) -> None:
        """Validate all fields and normalise the integer fields in place.

        Raises:
            ValueError: If a required string is empty, an enum-valued field
                holds an undeclared value, or an integer field is out of range.
            TypeError: If an integer field is something ``int()`` rejects with
                a ``TypeError`` (raised by the coercion helpers).
        """
        # Checks run in a fixed order, so the first reported error for a
        # multiply-invalid instance is deterministic.
        for name in ("model_id", "model_family", "backend", "device", "torch_dtype"):
            if not getattr(self, name):
                raise ValueError(f"{name} is required")
        # Allowed values come from the Literal aliases themselves, so runtime
        # validation cannot drift from the declared types.
        self._require_choice(self.split, get_args(EvaluationSplit), "split")
        self._require_choice(self.lane, get_args(EvaluationLane), "lane")
        self._require_choice(self.prompt_family, get_args(PromptFamily), "prompt_family")
        if not self.dataset_name:
            raise ValueError("dataset_name is required")
        self.prompt_count = _coerce_positive_int(self.prompt_count, field_name="prompt_count")
        self.batch_size = _coerce_positive_int(self.batch_size, field_name="batch_size")
        if not self.truth_type:
            raise ValueError("truth_type is required")
        self._require_choice(self.truth_type, get_args(HarnessTruthType), "truth_type")
        if not self.effective_budget_rule:
            raise ValueError("effective_budget_rule is required")
        # Optional integer fields are only validated when supplied.
        if self.context_length is not None:
            self.context_length = _coerce_positive_int(self.context_length, field_name="context_length")
        if self.decode_steps is not None:
            self.decode_steps = _coerce_non_negative_int(self.decode_steps, field_name="decode_steps")
        if self.eval_steps is not None:
            self.eval_steps = _coerce_non_negative_int(self.eval_steps, field_name="eval_steps")

    def to_dict(self) -> dict[str, Any]:
        """Return all fields as a new dict (via ``dataclasses.asdict``, which also copies ``notes``)."""
        return asdict(self)
@dataclass(slots=True)
class EvaluationRecord:
metadata: EvaluationMetadata
metrics: dict[str, Any]
source_result: dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> dict[str, Any]:
return {
"metadata": self.metadata.to_dict(),
"metrics": dict(self.metrics),
"source_result": dict(self.source_result),
}
def derive_standard_metrics(result: Mapping[str, Any], metadata: EvaluationMetadata | None = None) -> dict[str, Any]:
    """Collect the standard systems, quality and diagnostic metrics from *result*.

    Whitelisted keys present in *result* are copied as-is, in this order:

    1. systems keys;
    2. ``effective_bytes_per_token`` (see below);
    3. quality keys;
    4. diagnostic keys;
    5. ``decode_backend_trace``;
    6. ``page_format_histogram`` (only when non-empty).

    ``effective_bytes_per_token`` is taken from *result* when ``float()``
    accepts it. Otherwise it is estimated as ``resident_bytes`` divided by a
    context length. That length is ``metadata.context_length`` when metadata
    is given and the value is not None, and ``_pick_context_length(result)``
    otherwise. The entry is omitted when neither route yields a value.

    Args:
        result: Raw result mapping produced by a run.
        metadata: Optional run metadata, consulted only for the context length.

    Returns:
        A new dict of metrics.
    """
    systems_keys = (
        "dotcache_decode_ms_per_step",
        "ttft_ms",
        "p95_decode_ms_per_step",
        "resident_bytes",
        "prefill_ms",
        "execution_shortlist_selected_pages",
        "execution_shortlist_total_pages",
    )
    quality_keys = (
        "teacher_forced_loss_delta",
        "teacher_forced_perplexity_ratio",
        "teacher_forced_logit_max_abs_error",
        "teacher_forced_logit_mean_abs_error",
        "teacher_forced_logit_rmse",
        "teacher_forced_token_agreement_rate",
        "teacher_forced_target_match_rate",
    )
    diagnostic_keys = (
        "shortlist_recall_exact_top_recall_mean",
        "execution_decode_shortlist_selection_ms_total",
        "execution_decode_shortlist_candidate_approx_scoring_ms_total",
        "execution_decode_shortlist_candidate_ranking_ms_total",
        "execution_decode_shortlist_materialization_ms_total",
        "execution_decode_backend_call_non_backend_ms_total",
        "replay_output_max_abs_error",
        # The raw backend trace is passed through right after the diagnostics.
        "decode_backend_trace",
    )

    def present(keys: tuple[str, ...]) -> dict[str, Any]:
        # Copy only the keys that actually exist, preserving tuple order.
        return {key: result[key] for key in keys if key in result}

    metrics: dict[str, Any] = present(systems_keys)

    bytes_per_token = _pick_float(result, "effective_bytes_per_token")
    if bytes_per_token is None:
        resident = _pick_float(result, "resident_bytes")
        length = None if metadata is None else metadata.context_length
        if length is None:
            length = _pick_context_length(result)
        if resident is not None and length is not None and length > 0:
            bytes_per_token = resident / float(length)
    if bytes_per_token is not None:
        metrics["effective_bytes_per_token"] = float(bytes_per_token)

    metrics.update(present(quality_keys))
    metrics.update(present(diagnostic_keys))

    histogram = page_format_histogram_from_result(result)
    if histogram:
        metrics["page_format_histogram"] = histogram
    return metrics
def build_evaluation_record(
    metadata: EvaluationMetadata,
    result: Mapping[str, Any],
    *,
    include_source_result: bool = True,
) -> EvaluationRecord:
    """Wrap *metadata* and the metrics derived from *result* in an EvaluationRecord.

    Args:
        metadata: Run metadata. It is stored on the record and also passed to
            ``derive_standard_metrics``.
        result: Raw result mapping produced by the run.
        include_source_result: When True (default), a shallow ``dict`` copy of
            *result* is stored as ``source_result``; otherwise an empty dict.

    Returns:
        The assembled EvaluationRecord.
    """
    derived = derive_standard_metrics(result, metadata=metadata)
    snapshot: dict[str, Any] = {}
    if include_source_result:
        snapshot = dict(result)
    return EvaluationRecord(metadata=metadata, metrics=derived, source_result=snapshot)
|