from __future__ import annotations from dataclasses import asdict, dataclass, field from typing import Any, Literal, Mapping EvaluationSplit = Literal["calibration", "held_out"] EvaluationLane = Literal["systems", "quality", "diagnostic"] PromptFamily = Literal["synthetic_exact_length", "held_out_natural_text", "standardized_long_context"] HarnessTruthType = Literal["reference_trace", "paged_runtime"] _PAGE_MODE_FIELDS: tuple[str, ...] = ("m0", "m1", "m2", "m3", "m4", "t3") def _coerce_non_negative_int(value: Any, *, field_name: str) -> int: parsed = int(value) if parsed < 0: raise ValueError(f"{field_name} must be non-negative") return parsed def _coerce_positive_int(value: Any, *, field_name: str) -> int: parsed = int(value) if parsed <= 0: raise ValueError(f"{field_name} must be positive") return parsed def page_format_histogram_from_result(result: Mapping[str, Any]) -> dict[str, int]: histogram: dict[str, int] = {} for kind_prefix in ("k", "v"): kind_label = kind_prefix.upper() for mode in _PAGE_MODE_FIELDS: key = f"{kind_prefix}_{mode}_pages" if key not in result: continue histogram[f"{kind_label}:{mode.upper()}"] = max(int(result.get(key, 0) or 0), 0) return histogram def _pick_context_length(result: Mapping[str, Any]) -> int | None: for key in ("context_length_effective", "context_length", "prompt_length", "sequence_length", "prefix_length"): value = result.get(key) if value is None: continue try: parsed = int(value) except (TypeError, ValueError): continue if parsed > 0: return parsed return None def _pick_float(result: Mapping[str, Any], *keys: str) -> float | None: for key in keys: if key not in result: continue value = result.get(key) if value is None: continue try: return float(value) except (TypeError, ValueError): continue return None @dataclass(slots=True) class EvaluationMetadata: model_id: str model_family: str backend: str device: str torch_dtype: str split: EvaluationSplit lane: EvaluationLane prompt_family: PromptFamily dataset_name: str prompt_count: int batch_size: int truth_type: HarnessTruthType effective_budget_rule: str context_length: int | None = None decode_steps: int | None = None eval_steps: int | None = None notes: list[str] = field(default_factory=list) def __post_init__(self) -> None: if not self.model_id: raise ValueError("model_id is required") if not self.model_family: raise ValueError("model_family is required") if not self.backend: raise ValueError("backend is required") if not self.device: raise ValueError("device is required") if not self.torch_dtype: raise ValueError("torch_dtype is required") if self.split not in ("calibration", "held_out"): raise ValueError("split must be calibration or held_out") if self.lane not in ("systems", "quality", "diagnostic"): raise ValueError("lane must be systems, quality, or diagnostic") if self.prompt_family not in ("synthetic_exact_length", "held_out_natural_text", "standardized_long_context"): raise ValueError("prompt_family is invalid") if not self.dataset_name: raise ValueError("dataset_name is required") self.prompt_count = _coerce_positive_int(self.prompt_count, field_name="prompt_count") self.batch_size = _coerce_positive_int(self.batch_size, field_name="batch_size") if not self.truth_type: raise ValueError("truth_type is required") if self.truth_type not in ("reference_trace", "paged_runtime"): raise ValueError("truth_type must be reference_trace or paged_runtime") if not self.effective_budget_rule: raise ValueError("effective_budget_rule is required") if self.context_length is not None: self.context_length = _coerce_positive_int(self.context_length, field_name="context_length") if self.decode_steps is not None: self.decode_steps = _coerce_non_negative_int(self.decode_steps, field_name="decode_steps") if self.eval_steps is not None: self.eval_steps = _coerce_non_negative_int(self.eval_steps, field_name="eval_steps") def to_dict(self) -> dict[str, Any]: return asdict(self) @dataclass(slots=True) class EvaluationRecord: metadata: EvaluationMetadata metrics: dict[str, Any] source_result: dict[str, Any] = field(default_factory=dict) def to_dict(self) -> dict[str, Any]: return { "metadata": self.metadata.to_dict(), "metrics": dict(self.metrics), "source_result": dict(self.source_result), } def derive_standard_metrics(result: Mapping[str, Any], metadata: EvaluationMetadata | None = None) -> dict[str, Any]: metrics: dict[str, Any] = {} systems_keys = ( "dotcache_decode_ms_per_step", "ttft_ms", "p95_decode_ms_per_step", "resident_bytes", "prefill_ms", "execution_shortlist_selected_pages", "execution_shortlist_total_pages", ) for key in systems_keys: if key in result: metrics[key] = result[key] effective_bytes_per_token = _pick_float(result, "effective_bytes_per_token") if effective_bytes_per_token is None: resident_bytes = _pick_float(result, "resident_bytes") context_length = metadata.context_length if metadata is not None else None if context_length is None: context_length = _pick_context_length(result) if resident_bytes is not None and context_length is not None and context_length > 0: effective_bytes_per_token = resident_bytes / float(context_length) if effective_bytes_per_token is not None: metrics["effective_bytes_per_token"] = float(effective_bytes_per_token) quality_keys = ( "teacher_forced_loss_delta", "teacher_forced_perplexity_ratio", "teacher_forced_logit_max_abs_error", "teacher_forced_logit_mean_abs_error", "teacher_forced_logit_rmse", "teacher_forced_token_agreement_rate", "teacher_forced_target_match_rate", ) for key in quality_keys: if key in result: metrics[key] = result[key] diagnostic_keys = ( "shortlist_recall_exact_top_recall_mean", "execution_decode_shortlist_selection_ms_total", "execution_decode_shortlist_candidate_approx_scoring_ms_total", "execution_decode_shortlist_candidate_ranking_ms_total", "execution_decode_shortlist_materialization_ms_total", "execution_decode_backend_call_non_backend_ms_total", "replay_output_max_abs_error", ) for key in diagnostic_keys: if key in result: metrics[key] = result[key] if "decode_backend_trace" in result: metrics["decode_backend_trace"] = result["decode_backend_trace"] histogram = page_format_histogram_from_result(result) if histogram: metrics["page_format_histogram"] = histogram return metrics def build_evaluation_record( metadata: EvaluationMetadata, result: Mapping[str, Any], *, include_source_result: bool = True, ) -> EvaluationRecord: metrics = derive_standard_metrics(result, metadata=metadata) source_result = dict(result) if include_source_result else {} return EvaluationRecord(metadata=metadata, metrics=metrics, source_result=source_result)