linvest21's picture
download
raw
34.4 kB
from __future__ import annotations
import hashlib
import json
import re
from collections import Counter
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from data_pipeline.all_role_defect_repair import corrected_answer
from data_pipeline.learning_pdf_to_jsonl import write_jsonl
from data_pipeline.repair_answer_quality import answer_admitted_to_training, repair_quality_checks
from eval.paired_eval_defect_ranker import classify_prediction
from n21.config import write_json
from n21.settings import SHFT_WORKSPACE_ROOT
# Version tag stamped onto every preference pair and onto the build manifest.
PREFERENCE_SCHEMA_VERSION = "shft_pairwise_preference_memory_v1"

# One complete ``<think>...</think>`` span: case-insensitive, allowed to cross
# newlines, and non-greedy so consecutive blocks are removed one by one.
THINK_BLOCK_RE = re.compile(
    r"<think>.*?</think>",
    flags=re.IGNORECASE | re.DOTALL,
)

# Every bucket name that ``failure_bucket_for_prediction`` can return.
DIAGNOSTIC_BUCKETS = (
    "leverage_math",
    "margin_analysis",
    "cash_flow_reasoning",
    "eps_quality",
    "style_fact_inference",
    "neutral_language",
    "risk_tradeoff",
    "valuation_math",
    "accounting_sec_extraction",
    "moat_reasoning",
    "risk_premium_discount_rate",
    "investment_memo_synthesis",
    "hallucination_uncertainty",
)

# Repair lanes accepted by ``normalized_repair_strategy``; the first entry is
# the lane used when no strategy is given.
REPAIR_STRATEGIES = (
    "generic_loss_targeted",
    "hard_negative_dpo",
    "critical_safety_repair",
    "human_failure_repair",
    "answer_quality_repair",
)
def utc_now() -> str:
return datetime.now(UTC).replace(microsecond=0).isoformat().replace("+00:00", "Z")
def read_json(path: Path) -> dict[str, Any]:
    """Load and return the JSON document stored at *path*.

    The file is decoded as ``utf-8-sig`` so a leading byte-order mark is
    tolerated.
    """
    document_text = path.read_text(encoding="utf-8-sig")
    return json.loads(document_text)
def read_jsonl(path: Path) -> list[dict[str, Any]]:
    """Load a JSON-Lines file into a list of objects.

    Blank (whitespace-only) lines are ignored and a UTF-8 byte-order mark is
    tolerated.

    Raises:
        ValueError: if a non-blank line decodes to something other than a JSON
            object; the message names ``path:line``. Malformed JSON propagates
            the decoder's own error.
    """
    records: list[dict[str, Any]] = []
    with path.open("r", encoding="utf-8-sig") as stream:
        for number, text in enumerate(stream, start=1):
            if not text.strip():
                continue
            record = json.loads(text)
            if isinstance(record, dict):
                records.append(record)
            else:
                raise ValueError(f"{path}:{number} must contain a JSON object")
    return records
def sha256_file(path: Path) -> str:
    """Return the hexadecimal SHA-256 digest of the file at *path*.

    The file is streamed in 1 MiB chunks so large outputs are never loaded
    into memory at once.
    """
    hasher = hashlib.sha256()
    chunk_size = 1024 * 1024
    with path.open("rb") as stream:
        while block := stream.read(chunk_size):
            hasher.update(block)
    return hasher.hexdigest()
def text_value(row: dict[str, Any], *keys: str) -> str:
    """Return the first non-blank string stored under any of *keys*, stripped.

    Keys are tried in the order given. Non-string and whitespace-only values
    are skipped. Returns ``""`` when no key qualifies.
    """
    stripped = (
        value.strip()
        for value in map(row.get, keys)
        if isinstance(value, str) and value.strip()
    )
    return next(stripped, "")
def clean_model_response(text: str) -> str:
    """Remove ``<think>`` reasoning markup from a model response.

    The response is cleaned in three steps:

    1. Every complete ``<think>...</think>`` block is deleted.
    2. If an orphan closing tag is still present, everything up to and
       including the last one is discarded.
    3. Any remaining bare ``<think>``/``</think>`` tags are deleted.

    NOTE(review): for an unclosed ``<think>`` only the tag itself is removed;
    the text that follows it is kept in the answer.
    """
    remainder = THINK_BLOCK_RE.sub("", text).strip()
    if "</think>" in remainder.lower():
        *_, remainder = re.split(r"</think>", remainder, flags=re.IGNORECASE)
        remainder = remainder.strip()
    return re.sub(r"</?think>", "", remainder, flags=re.IGNORECASE).strip()
def score_number(value: Any) -> float | None:
if isinstance(value, (int, float)):
return float(value)
if isinstance(value, dict):
raw = value.get("score")
if isinstance(raw, (int, float)):
return float(raw)
return None
def critical_pass(value: Any) -> bool | None:
if isinstance(value, dict) and isinstance(value.get("critical_pass"), bool):
return bool(value["critical_pass"])
return None
def candidate_lost(prediction: dict[str, Any]) -> bool:
    """Return True when the paired eval shows the candidate losing to the baseline.

    Evidence is consulted in this order:

    1. An outcome label (first truthy of ``winner``, ``pairwise_result``,
       ``outcome``). If it names a baseline win, the result is True.
       Other labels fall through to the next step.
    2. A numeric ``delta``. The candidate lost iff it is negative.
    3. The candidate and baseline scores. The candidate lost iff both are
       present and the candidate's is lower.
    """
    label = (
        prediction.get("winner")
        or prediction.get("pairwise_result")
        or prediction.get("outcome")
        or ""
    )
    loss_labels = {"baseline", "baseline_win", "baseline_wins", "candidate_loss", "loss"}
    if str(label).lower() in loss_labels:
        return True
    delta = prediction.get("delta")
    if isinstance(delta, (int, float)):
        return float(delta) < 0
    candidate_score = score_number(prediction.get("candidate_score"))
    baseline_score = score_number(prediction.get("baseline_score"))
    if candidate_score is None or baseline_score is None:
        return False
    return candidate_score < baseline_score
def candidate_critical_failed(prediction: dict[str, Any]) -> bool:
    """Return True only when the candidate is explicitly marked as failing critical checks.

    The flag nested in ``candidate_score`` takes precedence. The top-level
    ``candidate_critical_pass`` is consulted only when that nested flag is
    unavailable. Anything other than a literal ``False`` counts as "not failed".
    """
    nested = critical_pass(prediction.get("candidate_score"))
    flag = nested if nested is not None else prediction.get("candidate_critical_pass")
    return flag is False
def normalized_repair_strategy(repair_strategy: str | None) -> str:
    """Validate a repair-strategy name and return it stripped.

    ``None`` or an empty string selects ``"generic_loss_targeted"``.

    Raises:
        ValueError: if the stripped name is not listed in ``REPAIR_STRATEGIES``.
    """
    requested = repair_strategy if repair_strategy else "generic_loss_targeted"
    strategy = requested.strip()
    if strategy in REPAIR_STRATEGIES:
        return strategy
    allowed = ", ".join(REPAIR_STRATEGIES)
    raise ValueError(f"unknown repair_strategy={strategy!r}; expected one of {allowed}")
def prediction_matches_repair_strategy(
    *,
    prediction: dict[str, Any],
    repair_strategy: str,
    include_critical_failures: bool,
) -> bool:
    """Decide whether a paired-eval row belongs in the selected repair lane.

    Lanes may be narrower than the legacy generic loss target:

    - ``critical_safety_repair`` takes only critical failures (and only when
      *include_critical_failures* is set), so ordinary pairwise losses that
      pass critical checks do not dilute it.
    - ``hard_negative_dpo`` takes only pairwise losses.
    - The generic, human-failure and answer-quality lanes take losses plus
      (optionally) critical failures.
    - Unknown lanes match nothing.
    """
    loss = candidate_lost(prediction)
    critical = candidate_critical_failed(prediction)
    critical_in_scope = include_critical_failures and critical
    loss_or_critical = loss or critical_in_scope
    lanes = {
        "generic_loss_targeted": loss_or_critical,
        "human_failure_repair": loss_or_critical,
        "answer_quality_repair": loss_or_critical,
        "critical_safety_repair": critical_in_scope,
        "hard_negative_dpo": loss,
    }
    return lanes.get(repair_strategy, False)
def defects_for_prediction(prediction: dict[str, Any]) -> list[str]:
    """Return the ordered, de-duplicated defect labels for a paired-eval row.

    The labels are assembled in this order:

    - Labels from ``classify_prediction`` come first: its keys when it returns
      a dict, otherwise its items.
    - ``"pairwise_loss"`` and ``"critical_failure"`` are appended when those
      conditions hold.
    - Duplicates keep their first position.
    - An empty result becomes ``["unclassified_pairwise_failure"]``.
    """
    classified = classify_prediction(prediction)
    labels = list(classified.keys()) if isinstance(classified, dict) else list(classified)
    if candidate_lost(prediction):
        labels.append("pairwise_loss")
    if candidate_critical_failed(prediction):
        labels.append("critical_failure")
    unique = [label for position, label in enumerate(labels) if label not in labels[:position]]
    return unique or ["unclassified_pairwise_failure"]
def winner_for_prediction(prediction: dict[str, Any]) -> str:
    """Name the pairwise winner: ``"baseline"``, ``"candidate"`` or ``"tie"``.

    The winner is decided as follows:

    1. A candidate loss (per ``candidate_lost``) means the baseline won.
    2. A positive numeric ``delta`` means the candidate won.
    3. Otherwise the two scores are compared directly.
    4. Missing or equal scores are a tie.
    """
    if candidate_lost(prediction):
        return "baseline"
    delta = prediction.get("delta")
    positive_delta = isinstance(delta, (int, float)) and float(delta) > 0
    if positive_delta:
        return "candidate"
    cand = score_number(prediction.get("candidate_score"))
    base = score_number(prediction.get("baseline_score"))
    if cand is None or base is None:
        return "tie"
    if cand > base:
        return "candidate"
    if base > cand:
        return "baseline"
    return "tie"
def failure_bucket_for_prediction(prediction: dict[str, Any], defects: list[str] | None = None) -> str:
defects = defects or defects_for_prediction(prediction)
text = " ".join(
[
str(prediction.get("id") or ""),
str(prediction.get("task") or ""),
str(prediction.get("prompt") or ""),
str(prediction.get("candidate_answer") or ""),
" ".join(defects),
]
).lower()
if "hallucination" in defects or "unsupported" in text or "certainty" in text:
return "hallucination_uncertainty"
if any(term in text for term in ("debt-to-ebitda", "debt to ebitda", "ebitda", "leverage")):
return "leverage_math"
if any(term in text for term in ("gross margin", "operating margin", "margin picture")):
return "margin_analysis"
if any(term in text for term in ("free cash flow", "operating cash flow", "capex", "cash flow")):
return "cash_flow_reasoning"
if any(term in text for term in ("eps", "buyback", "tax rate", "earnings per share")):
return "eps_quality"
if "fact_inference_separation" in defects:
return "style_fact_inference"
if "risk_tradeoff_framing" in defects and "neutral" in text:
return "neutral_language"
if "risk_tradeoff_framing" in defects:
return "risk_tradeoff"
if any(term in text for term in ("discount rate", "risk premium", "wacc", "cost of equity", "terminal value")):
return "risk_premium_discount_rate"
if any(term in text for term in ("moat", "competitive", "pricing power", "switching cost", "durability")):
return "moat_reasoning"
if any(term in text for term in ("10-k", "10-q", "sec", "filing", "reported", "margin", "revenue", "backlog", "cash flow", "eps", "gaap")):
return "accounting_sec_extraction"
if "numeric_reasoning" in defects or any(term in text for term in ("valuation", "multiple", "ratio", "percent", "%", "growth", "margin")):
return "valuation_math"
return "investment_memo_synthesis"
def bucket_weighted_order(predictions: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Interleave predictions round-robin across their failure buckets.

    Plain first-come truncation under ``max_records`` lets one over-represented
    bucket crowd out the others. Here buckets are visited in the order their
    first member appears, and rows keep their relative order inside a bucket.
    Any prefix of the result therefore takes one row from every bucket before
    taking a second row from any of them.
    """
    by_bucket: dict[str, list[dict[str, Any]]] = {}
    for item in predictions:
        by_bucket.setdefault(failure_bucket_for_prediction(item), []).append(item)
    rounds = max((len(group) for group in by_bucket.values()), default=0)
    return [
        group[position]
        for position in range(rounds)
        for group in by_bucket.values()
        if position < len(group)
    ]
def release_run_prefix(run_id: str) -> str:
    """Return the ``run_<name>_v<major>_<minor>`` prefix of *run_id*.

    If *run_id* does not match that shape, it is returned unchanged.
    """
    found = re.match(r"^(run_.+?_v\d+_\d+)", run_id)
    if found is None:
        return run_id
    return found.group(1)
def collect_eligible_predictions(
    *,
    run_id: str,
    predictions_path: Path,
    include_critical_failures: bool,
    include_historical: bool,
    min_records: int,
    max_records: int,
    repair_strategy: str = "generic_loss_targeted",
) -> tuple[list[tuple[str, dict[str, Any]]], Counter[str]]:
    """Collect paired-eval rows that belong to the selected repair lane.

    Rows from *predictions_path* are attributed to *run_id*. If
    *include_historical* is set and fewer than ``min(min_records, max_records)``
    rows were found, sibling run directories are scanned until that target is
    met. The scan covers ``SHFT_WORKSPACE_ROOT/runs`` directories that share
    this run's release prefix, visited in name-descending order, skipping the
    current run and ``*_paired_eval_code`` directories. Each run file is read
    in full, so the target can be overshot.

    Rows are de-duplicated per source run by ``id``. Rows without an ``id``
    fall back to the same prompt fields ``build_pair`` reads (``prompt``,
    ``question``, ``input``).

    Returns:
        ``(eligible, skipped)``: ``(source_run_id, prediction)`` tuples in
        discovery order, and a counter of skip reasons.

    Raises:
        ValueError: if *repair_strategy* is unknown.
    """
    repair_strategy = normalized_repair_strategy(repair_strategy)
    skipped: Counter[str] = Counter()
    eligible: list[tuple[str, dict[str, Any]]] = []
    seen: set[tuple[str, str]] = set()
    filtered_reason = (
        "not_loss_or_critical_failure"
        if repair_strategy == "generic_loss_targeted"
        else f"strategy_filtered_{repair_strategy}"
    )
    target = min(min_records, max_records)

    def add_from(path: Path, source_run_id: str) -> None:
        for prediction in read_jsonl(path):
            if not prediction_matches_repair_strategy(
                prediction=prediction,
                repair_strategy=repair_strategy,
                include_critical_failures=include_critical_failures,
            ):
                skipped[filtered_reason] += 1
                continue
            # Fall back through the same prompt fields build_pair uses. Keying
            # on ``prompt`` alone collapsed every id-less ``question``/``input``
            # row onto the "" key and discarded the rest as duplicates.
            identity = (
                prediction.get("id")
                or prediction.get("prompt")
                or prediction.get("question")
                or prediction.get("input")
                or ""
            )
            key = (source_run_id, str(identity))
            if key in seen:
                skipped["duplicate_prediction"] += 1
                continue
            seen.add(key)
            eligible.append((source_run_id, prediction))

    add_from(predictions_path, run_id)
    if include_historical and len(eligible) < target:
        prefix = release_run_prefix(run_id)
        run_dirs = sorted(
            (SHFT_WORKSPACE_ROOT / "runs").glob(f"{prefix}*"),
            key=lambda item: item.name,
            reverse=True,
        )
        for run_dir in run_dirs:
            if len(eligible) >= target:
                break
            if run_dir.name == run_id or run_dir.name.endswith("_paired_eval_code"):
                continue
            historical_predictions = run_dir / "eval" / "paired_predictions.jsonl"
            if historical_predictions.exists():
                add_from(historical_predictions, run_dir.name)
    return eligible, skipped
def judge_rationale_for_prediction(prediction: dict[str, Any], defects: list[str] | None = None) -> list[str]:
    """Explain in short sentences how the paired judge scored this row.

    The explanation covers:

    - the winner;
    - both scores, when numeric;
    - any critical-pass failure;
    - the names of failed candidate checks;
    - the detected defect labels.

    Args:
        prediction: The paired-eval row.
        defects: Precomputed defect labels. When empty or ``None`` they are
            recomputed with ``defects_for_prediction``.
    """
    defects = defects or defects_for_prediction(prediction)
    rationale: list[str] = []
    winner = winner_for_prediction(prediction)
    if winner == "baseline":
        rationale.append("baseline scored higher than candidate on the paired frozen-eval rubric")
    elif winner == "candidate":
        rationale.append("candidate scored higher than baseline on the paired frozen-eval rubric")
    else:
        rationale.append("candidate tied baseline on the paired frozen-eval rubric")
    candidate = score_number(prediction.get("candidate_score"))
    baseline = score_number(prediction.get("baseline_score"))
    if candidate is not None and baseline is not None:
        rationale.append(f"candidate_score={candidate:.4f}; baseline_score={baseline:.4f}")
    if candidate_critical_failed(prediction):
        rationale.append("candidate failed at least one critical-pass condition")
    # ``candidate_score`` may be a bare number (score_number accepts that), or
    # it may be missing or None. Only a mapping can carry per-check results;
    # calling ``.get`` on anything else raised AttributeError.
    candidate_payload = prediction.get("candidate_score")
    checks = candidate_payload.get("checks") if isinstance(candidate_payload, dict) else None
    if isinstance(checks, dict):
        failed = [name for name, ok in checks.items() if ok is False]
        if failed:
            rationale.append("failed candidate checks: " + ", ".join(sorted(failed)))
    if defects:
        rationale.append("detected defects: " + ", ".join(defects))
    return rationale
# Task labels that ``prompt_requires_numeric`` treats as numeric regardless
# of the prompt text.
NUMERIC_TASKS = {"quantitative_qa"}

# Chatty opening phrases ("Sure", "Here is", "As an AI", "Let me", ...)
# anchored at the start of a response. Not referenced in this part of the
# file; presumably consumed elsewhere.
PREAMBLE_RE = re.compile(
    r"^\s*(sure|certainly|of course|here(?:'s| is| are)|as an ai|"
    r"i (?:can|will|would|am)|okay|ok|let me|i'd be happy)\b",
    flags=re.IGNORECASE,
)

# Words and symbols that signal a direction or size of change (to/from,
# rose/fell, ratio, margin, "x", "%"). Also not referenced in this part of
# the file.
DIRECTION_RE = re.compile(
    r"\b(to|from|increase[ds]?|decrease[ds]?|rose|fell|grew|declin\w*|"
    r"higher|lower|delta|change[ds]?|ratio|margin|x)\b|%",
    flags=re.IGNORECASE,
)
def prompt_requires_numeric(prompt: str, task: Any) -> bool:
    """Return True when a numeric answer is expected.

    A numeric answer is expected when *task* is in ``NUMERIC_TASKS``, or the
    lower-cased prompt contains one of these cues:

    - the word "calculate";
    - "from"/"to" followed by a digit, optionally ``$``-prefixed;
    - a percent sign;
    - the word "ratio".
    """
    return (
        str(task) in NUMERIC_TASKS
        or re.search(r"\bcalculate\b|from \$?\d|to \$?\d|%|\bratio\b", prompt.lower()) is not None
    )
def repair_acceptance_checks(*, prompt: str, chosen: str, rejected: str, task: Any = None) -> dict[str, bool]:
    """Evaluate a proposed ``chosen`` answer against the admission rubric.

    The result starts from ``repair_quality_checks`` on *chosen* and adds:

    - ``no_preamble``: alias of ``no_chatty_preamble``.
    - ``numeric_answer_for_numeric_prompt``: alias of
      ``numeric_completion_when_required``.
    - ``differs_from_rejected``: whitespace-trimmed inequality with *rejected*.
    - ``meets_min_length``: at least 12 whitespace-separated tokens.
    """
    quality = repair_quality_checks(prompt=prompt, answer=chosen, task=task)
    checks = dict(quality)
    checks["no_preamble"] = quality["no_chatty_preamble"]
    checks["numeric_answer_for_numeric_prompt"] = quality["numeric_completion_when_required"]
    checks["differs_from_rejected"] = chosen.strip() != rejected.strip()
    checks["meets_min_length"] = len(chosen.split()) >= 12
    return checks
def repair_admitted_to_training(checks: dict[str, bool]) -> bool:
    """Return whether a repair target may enter training.

    Admission requires all of the following:

    - the base answer-quality gate passes;
    - ``differs_from_rejected`` is literally ``True``;
    - ``meets_min_length`` is literally ``True``.
    """
    base_admitted = answer_admitted_to_training(checks)
    if not base_admitted:
        return base_admitted
    required = ("differs_from_rejected", "meets_min_length")
    return all(checks.get(name) is True for name in required)
# One numeric literal in free text. The integer part is either comma-grouped
# ("1,200", "12,345,678") or a plain digit run, and either form may carry a
# decimal fraction. The (?!\d) guard keeps "1,2345" from being read as
# 1,234 followed by 5. A leading "$" or trailing "%" is consumed but not
# captured.
_NUMBER_RE = re.compile(r"\$?(\d{1,3}(?:,\d{3})+(?!\d)(?:\.\d+)?|\d+(?:\.\d+)?)%?")


def _numbers(prompt: str) -> list[float]:
    """Return every numeric literal in *prompt* as a float, in order of appearance.

    Thousands separators are honoured, so "$1,200 million" yields ``1200.0``.
    The previous pattern split it into ``[1.0, 200.0]``, which corrupted the
    leverage and margin templates. The regex only captures digit strings, so
    ``float`` cannot fail.
    """
    return [float(raw.replace(",", "")) for raw in _NUMBER_RE.findall(prompt)]
def _number_anchor_text(prompt: str) -> str:
    """Join the first four numbers in *prompt* (``%g`` formatted) with ", ".

    Returns ``""`` when the prompt contains no numbers.
    """
    leading = _numbers(prompt)[:4]
    return ", ".join(format(value, "g") for value in leading)
def _ensure_terminal_numeric_conclusion(*, prompt: str, answer: str, task: Any = None) -> str:
    """Append a numeric-anchor conclusion line when the quality check demands one.

    A line is appended only when ``terminal_numeric_conclusion_when_required``
    is explicitly ``False`` and the prompt contains numbers to anchor on.
    Otherwise *answer* is returned unchanged.
    """
    checks = repair_quality_checks(prompt=prompt, answer=answer, task=task)
    needs_anchor = checks.get("terminal_numeric_conclusion_when_required") is False
    anchors = _number_anchor_text(prompt) if needs_anchor else ""
    if not anchors:
        return answer
    return (
        f"{answer.rstrip()}\nConclusion numeric anchor: the final analyst conclusion "
        f"should preserve {anchors} as the key quantitative evidence."
    )
def _structured_reference_answer(*, prompt: str, bucket: str) -> str:
    """Build a rubric-shaped reference answer for *prompt* in failure *bucket*.

    Dedicated templates cover ``leverage_math``, ``margin_analysis``,
    ``eps_quality`` and ``cash_flow_reasoning``. ``leverage_math`` needs three
    numbers, read as debt before, debt after and EBITDA (assumed to appear in
    that order); it also needs a non-zero EBITDA. Every other case gets a
    generic facts/calculation/inference/risk/conclusion scaffold, so the
    result is never empty.

    The leverage template states the direction the numbers actually show. It
    previously always said "increased" / "an increase of", which produced
    self-contradictory targets such as "an increase of -1.000x".
    """
    values = _numbers(prompt)
    # A zero EBITDA makes the ratio meaningless, so use the generic template.
    if bucket == "leverage_math" and len(values) >= 3 and values[2]:
        before, after, ebitda = values[0], values[1], values[2]
        before_ratio = before / ebitda
        after_ratio = after / ebitda
        delta = after_ratio - before_ratio
        if delta > 0:
            facts = f"Debt increased from ${before:g} million to ${after:g} million"
            change = f"an increase of {delta:.3f}x"
            inference = "The reported fact is higher debt with unchanged EBITDA; the inference is that leverage risk increased."
            conclusion = f"Flag the leverage increase to {after_ratio:.3f}x, but do not infer distress from this fact alone."
        elif delta < 0:
            facts = f"Debt decreased from ${before:g} million to ${after:g} million"
            change = f"a decrease of {-delta:.3f}x"
            inference = "The reported fact is lower debt with unchanged EBITDA; the inference is that leverage risk decreased."
            conclusion = f"Note the leverage decrease to {after_ratio:.3f}x, but do not infer improved credit quality from this fact alone."
        else:
            facts = f"Debt held at ${after:g} million"
            change = "no change"
            inference = "The reported fact is unchanged debt with unchanged EBITDA; the inference is that leverage risk is unchanged."
            conclusion = f"Leverage is unchanged at {after_ratio:.3f}x; judge that level against sector norms before drawing a conclusion."
        return (
            f"Reported facts: {facts} while EBITDA stayed flat at ${ebitda:g} million.\n"
            f"Calculation: Debt/EBITDA moved from {before:g}/{ebitda:g} = {before_ratio:.3f}x to {after:g}/{ebitda:g} = {after_ratio:.3f}x, {change}.\n"
            f"Inference: {inference}\n"
            "Risk/tradeoff: The risk is reduced balance-sheet flexibility if EBITDA weakens, while the offset is that the absolute leverage level must be judged against sector norms.\n"
            f"Conclusion: {conclusion}"
        )
    if bucket == "margin_analysis" and len(values) >= 2:
        return (
            f"Reported facts: Gross margin improved from {values[0]:g}% to {values[1]:g}%, while operating margin fell because operating expenses increased.\n"
            "Calculation: The gross-margin change is favorable, but operating expenses absorbed more than the gross-profit benefit at the operating-income line.\n"
            "Inference: The reported facts separate product-level profitability from operating-cost discipline.\n"
            "Risk/tradeoff: The risk is operating expense deleverage; the offset is that improved gross margin may still be valuable if expense growth normalizes.\n"
            f"Conclusion: Treat the margin picture as mixed: gross margin improved to {values[1]:g}%, but operating-margin pressure remains the key risk."
        )
    if bucket == "eps_quality" and values:
        return (
            f"Reported facts: EPS rose {values[0]:g}% while revenue was flat, helped by buybacks and a lower tax rate.\n"
            "Calculation: The EPS increase should be decomposed into operating earnings, share-count reduction, and tax-rate benefit rather than treated as pure operating growth.\n"
            "Inference: The reported fact is EPS growth; the inference is that quality of growth may be weaker because non-operating or denominator effects contributed.\n"
            "Risk/tradeoff: The risk is overstating earnings quality; the offset is that buybacks can still create value if funded sustainably and shares are attractively valued.\n"
            f"Conclusion: Separate the {values[0]:g}% EPS gain from buyback and tax effects before upgrading the earnings-quality view."
        )
    if bucket == "cash_flow_reasoning" and values:
        return (
            f"Reported facts: Operating cash flow was positive, but free cash flow was negative because capex rose {values[0]:g}%.\n"
            "Calculation: Free cash flow equals operating cash flow minus capex, so higher capex can turn positive operating cash generation into negative free cash flow.\n"
            "Inference: The reported facts show investment spending consumed operating cash; whether that is good or bad depends on expected returns on the capex.\n"
            "Risk/tradeoff: The risk is weaker near-term cash conversion; the offset is that growth capex may support future capacity or returns.\n"
            f"Conclusion: Flag negative free cash flow and the {values[0]:g}% capex increase, then verify whether the spending is maintenance or growth investment."
        )
    number_text = ", ".join(f"{value:g}" for value in values[:4]) if values else "no explicit numeric value"
    conclusion_anchor = f" using the numeric anchor {number_text}" if values else ""
    return (
        f"Reported facts: {prompt.strip()}\n"
        f"Calculation: The answer should preserve the numeric anchor ({number_text}) and compare direction, magnitude, and denominator effects when applicable.\n"
        "Inference: Separate what the prompt reports from any conclusion about durability, valuation, or future risk.\n"
        "Risk/tradeoff: Identify the main risk, the offsetting consideration, and the next filing or metric needed to confirm the inference.\n"
        f"Conclusion: Provide a neutral analyst conclusion{conclusion_anchor} without investment advice, unsupported certainty, or a one-sided call."
    )
def synthesize_reference_answer(*, prediction: dict[str, Any], asset_class: str, role: str, defects: list[str]) -> str:
    """Synthesize a ``chosen`` answer when no gold or admissible baseline answer exists.

    The bucket-specific template from ``_structured_reference_answer`` is
    preferred. Otherwise ``corrected_answer`` is used, extended with any
    ``expected_points`` and stripped of ``<think>`` markup.

    NOTE(review): ``_structured_reference_answer`` ends in an unconditional
    generic template, so its result is never empty and the
    ``corrected_answer`` fallback below appears unreachable. Confirm intent.
    """
    question = text_value(prediction, "prompt", "question", "input")
    templated = _structured_reference_answer(
        prompt=question,
        bucket=failure_bucket_for_prediction(prediction, defects),
    )
    if templated:
        return templated
    draft = corrected_answer(prompt=question, asset_class=asset_class, role=role, defects=defects)
    expected = prediction.get("expected_points")
    if isinstance(expected, list):
        cleaned_points = [str(point).strip() for point in expected if str(point).strip()]
        if cleaned_points:
            draft = draft + "\n\nA complete answer should: " + "; ".join(cleaned_points) + "."
    return clean_model_response(draft)
def corrected_chosen_answer(
    *, prediction: dict[str, Any], asset_class: str, role: str, defects: list[str]
) -> tuple[str, str]:
    """Resolve the preferred (``chosen``) answer and report where it came from.

    The order of preference avoids the parity trap:

    1. A curated answer wins, returned as ``"gold_answer"``. The fields tried
       are human-corrected, corrected, reference, gold, then expected answer.
    2. The baseline answer is used only when the baseline actually beat the
       candidate, was not marked as failing critical checks, and passes the
       repair admission gate. It is returned as ``"baseline_winning_answer"``.
    3. Otherwise a rubric-grounded answer is synthesized, so ``chosen`` is not
       capped at baseline quality. It is returned as
       ``"rubric_grounded_synthetic"``.

    Returns:
        ``(answer, source_label)``.
    """
    task = prediction.get("task")
    gold = text_value(
        prediction,
        "human_corrected_answer",
        "corrected_answer",
        "reference_answer",
        "gold_answer",
        "expected_answer",
    )
    if gold:
        gold = _ensure_terminal_numeric_conclusion(
            prompt=text_value(prediction, "prompt", "question", "input"),
            answer=clean_model_response(gold),
            task=task,
        )
        if gold:
            return gold, "gold_answer"

    baseline = text_value(prediction, "baseline_answer", "baseline_response")
    baseline_eligible = (
        bool(baseline)
        and winner_for_prediction(prediction) == "baseline"
        and critical_pass(prediction.get("baseline_score")) is not False
    )
    if baseline_eligible:
        question = text_value(prediction, "prompt", "question", "input")
        repaired = _ensure_terminal_numeric_conclusion(
            prompt=question,
            answer=clean_model_response(baseline),
            task=task,
        )
        losing = clean_model_response(text_value(prediction, "candidate_answer", "candidate_response", "answer"))
        verdict = repair_acceptance_checks(prompt=question, chosen=repaired, rejected=losing, task=task)
        if repaired and repair_admitted_to_training(verdict):
            return repaired, "baseline_winning_answer"

    fallback = synthesize_reference_answer(
        prediction=prediction,
        asset_class=asset_class,
        role=role,
        defects=defects,
    )
    return fallback, "rubric_grounded_synthetic"
def validate_pair(pair: dict[str, Any]) -> list[str]:
    """Return human-readable schema problems for a preference record (empty if valid).

    The following are checked:

    - ``prompt``, ``chosen`` and ``rejected`` are non-blank strings free of
      think tags;
    - ``chosen`` differs from ``rejected``;
    - ``metadata`` is a dict with non-empty ``defect_types``, a truthy
      ``source_run_id``, and a ``repair_target`` dict whose
      ``admitted_to_training`` is literally ``True``.
    """
    problems: list[str] = []
    for field in ("prompt", "chosen", "rejected"):
        value = pair.get(field)
        if not (isinstance(value, str) and value.strip()):
            problems.append(f"{field} must be non-empty text")
        lowered = str(value or "").lower()
        if "<think" in lowered or "</think" in lowered:
            problems.append(f"{field} must not contain think tags")
    if pair.get("chosen") == pair.get("rejected"):
        problems.append("chosen and rejected must differ")

    metadata = pair.get("metadata")
    if not isinstance(metadata, dict):
        problems.append("metadata must be an object")
        return problems

    defect_types = metadata.get("defect_types")
    if not (isinstance(defect_types, list) and defect_types):
        problems.append("metadata.defect_types must be non-empty")
    if not metadata.get("source_run_id"):
        problems.append("metadata.source_run_id must be present")
    target = metadata.get("repair_target")
    if not isinstance(target, dict):
        problems.append("metadata.repair_target must be present")
    elif target.get("admitted_to_training") is not True:
        problems.append("metadata.repair_target.admitted_to_training must be true")
    return problems
def build_pair(
    *,
    prediction: dict[str, Any],
    asset_class: str,
    role: str,
    run_id: str,
    source_run_id: str | None = None,
    record_index: int,
    include_reasoning_chosen: bool,
    repair_strategy: str = "generic_loss_targeted",
) -> dict[str, Any] | None:
    """Turn one paired-eval row into a chosen/rejected preference record.

    ``rejected`` is the candidate answer. ``chosen`` comes from
    ``corrected_chosen_answer`` when *include_reasoning_chosen* is set,
    otherwise it is the baseline answer. Both have ``<think>`` markup stripped
    before they are tested for emptiness.

    Args:
        prediction: The paired-eval row.
        asset_class: Copied into the record metadata.
        role: Copied into the record metadata.
        run_id: Used as ``source_run_id`` when that is not given.
        source_run_id: The run the row was read from.
        record_index: Position of the record in the output.
        include_reasoning_chosen: Whether to resolve ``chosen`` via
            ``corrected_chosen_answer``.
        repair_strategy: Repair lane recorded in the metadata.

    Returns:
        ``None`` when the prompt, the rejected answer or the chosen answer is
        empty after cleaning. Otherwise the record, which is returned whether
        or not ``metadata.repair_target.admitted_to_training`` is true.

    Raises:
        ValueError: if *repair_strategy* is unknown.
    """
    repair_strategy = normalized_repair_strategy(repair_strategy)
    prompt = text_value(prediction, "prompt", "question", "input")
    rejected = clean_model_response(text_value(prediction, "candidate_answer", "candidate_response", "answer"))
    if not prompt or not rejected:
        return None
    defects = defects_for_prediction(prediction)
    if include_reasoning_chosen:
        chosen, chosen_source = corrected_chosen_answer(prediction=prediction, asset_class=asset_class, role=role, defects=defects)
    else:
        chosen = text_value(prediction, "baseline_answer", "baseline_response")
        chosen_source = "baseline_answer"
    # Clean before the emptiness test, as is done for ``rejected`` above. A
    # chosen answer consisting only of <think> reasoning is then reported as
    # missing, instead of yielding an empty chosen that is silently rejected
    # later as "not admitted".
    chosen = clean_model_response(chosen)
    if not chosen:
        return None
    bucket = failure_bucket_for_prediction(prediction, defects)
    rationale = judge_rationale_for_prediction(prediction, defects)
    acceptance = repair_acceptance_checks(prompt=prompt, chosen=chosen, rejected=rejected, task=prediction.get("task"))
    return {
        "schema_version": PREFERENCE_SCHEMA_VERSION,
        "prompt": prompt,
        "chosen": chosen,
        "rejected": rejected,
        "metadata": {
            "asset_class": asset_class,
            "role": role,
            "source_run_id": source_run_id or run_id,
            "source_prediction_id": prediction.get("id"),
            "source_task": prediction.get("task"),
            "source_candidate_score": score_number(prediction.get("candidate_score")),
            "source_baseline_score": score_number(prediction.get("baseline_score")),
            "source_candidate_critical_pass": critical_pass(prediction.get("candidate_score")),
            "source_delta": prediction.get("delta"),
            "pairwise_loss": candidate_lost(prediction),
            "critical_failure": candidate_critical_failed(prediction),
            "defect_types": defects,
            "failure_bucket": bucket,
            "judge_rationale": rationale,
            "chosen_source": chosen_source,
            "chosen_is_gold": chosen_source == "gold_answer",
            # Pairs whose target is a baseline answer cannot beat baseline quality.
            "parity_capped": chosen_source in {"baseline_answer", "baseline_winning_answer"},
            "repair_target": {
                "answer": chosen,
                "source": chosen_source,
                "acceptance_checks": acceptance,
                "admitted_to_training": repair_admitted_to_training(acceptance),
            },
            "repair_strategy": repair_strategy,
            "training_objective": "dpo_or_orpo_preference_optimization",
            "record_index": record_index,
            "created_at": utc_now(),
        },
    }
def write_markdown_summary(path: Path, result: dict[str, Any]) -> None:
    """Write a human-readable Markdown digest of a preference-memory build to *path*.

    Parent directories are created as needed. Defects are listed by
    descending count, with ties broken alphabetically.
    """
    summary = result["summary"]
    header = [
        "# SHFT Pairwise Preference Memory",
        "",
        f"- Run: `{result['run_id']}`",
        f"- Asset/role: `{result['asset_class']}/{result['role']}`",
        f"- Repair strategy: `{summary['repair_strategy']}`",
        f"- Preference pairs: `{summary['preference_pair_count']}`",
        f"- Pairwise losses captured: `{summary['pairwise_loss_pair_count']}`",
        f"- Critical failures captured: `{summary['critical_failure_pair_count']}`",
        f"- Output: `{result['output_path']}`",
        f"- SHA256: `{result['output_sha256']}`",
        "",
        "## Top Defects",
        "",
    ]
    ranked = sorted(summary["defect_counts"].items(), key=lambda entry: (-entry[1], entry[0]))
    defect_lines = [f"- `{defect}`: `{count}`" for defect, count in ranked]
    footer = [
        "",
        "## Training Use",
        "",
        "Use these rows as preference data where `chosen` is the desired response and `rejected` is the losing candidate response.",
        "This is the factual next SHFT input after a paired-eval failure; it targets measured pairwise losses instead of repeating another SFT breakout on the same corpus.",
        "",
    ]
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text("\n".join(header + defect_lines + footer), encoding="utf-8")
def build_pairwise_preference_data(
    *,
    run_id: str,
    asset_class: str,
    role: str,
    predictions_path: Path | None = None,
    output_path: Path | None = None,
    max_records: int = 500,
    min_records: int = 1,
    include_historical: bool = True,
    include_critical_failures: bool = True,
    include_reasoning_chosen: bool = True,
    repair_strategy: str = "generic_loss_targeted",
) -> dict[str, Any]:
    """Build a preference-optimization dataset from a run's paired-eval predictions.

    Steps:

    1. Read ``<run>/eval/paired_predictions.jsonl`` (or *predictions_path*).
    2. Select rows in the requested repair lane, optionally topping up from
       other runs that share the release prefix.
    3. Interleave them round-robin across failure buckets.
    4. Convert each row into a chosen/rejected pair.
    5. Keep only pairs whose repair target is admitted and which pass
       ``validate_pair``, up to *max_records*.

    Three artifacts are written into the directory of *output_path* (default
    ``<run>/preference_memory/preference_pairs.jsonl``): the pairs JSONL,
    ``preference_manifest.json`` (the returned dict) and
    ``preference_memory_summary.md``.

    Args:
        run_id: Run directory name under ``SHFT_WORKSPACE_ROOT/runs``.
        asset_class: Recorded on every pair.
        role: Recorded on every pair.
        predictions_path: Overrides the default predictions file.
        output_path: Overrides the default pairs JSONL path.
        max_records: Hard cap on emitted pairs; must be positive.
        min_records: Number of pairs needed for ``ok``; also the historical
            top-up target.
        include_historical: Whether sibling runs may be scanned.
        include_critical_failures: Whether critical-check failures count as
            eligible.
        include_reasoning_chosen: Whether ``chosen`` comes from
            ``corrected_chosen_answer``; otherwise the baseline answer is used.
        repair_strategy: Repair lane name (see ``REPAIR_STRATEGIES``).

    Returns:
        The manifest dict. ``ok`` is true when at least *min_records* pairs
        were produced and no schema errors were recorded.

    Raises:
        ValueError: if *max_records* is not positive or *repair_strategy* is unknown.
        FileNotFoundError: if the paired predictions file does not exist.
    """
    if max_records <= 0:
        raise ValueError("max_records must be positive")
    repair_strategy = normalized_repair_strategy(repair_strategy)
    run_path = SHFT_WORKSPACE_ROOT / "runs" / run_id
    predictions = predictions_path or run_path / "eval" / "paired_predictions.jsonl"
    if not predictions.exists():
        raise FileNotFoundError(f"paired predictions not found: {predictions}")
    # Optional paired-eval report beside the predictions. Only its
    # "improvement" section is copied into the manifest.
    report_path = predictions.parent / "paired_eval_report.json"
    report = read_json(report_path) if report_path.exists() else {}
    # The manifest and Markdown summary sit next to the pairs file.
    output = output_path or run_path / "preference_memory" / "preference_pairs.jsonl"
    manifest_path = output.with_name("preference_manifest.json")
    markdown_path = output.with_name("preference_memory_summary.md")
    pairs: list[dict[str, Any]] = []
    schema_errors: list[str] = []
    skipped: Counter[str] = Counter()
    defect_counts: Counter[str] = Counter()
    bucket_selection_counts: Counter[str] = Counter()
    eligible_with_source, collection_skipped = collect_eligible_predictions(
        run_id=run_id,
        predictions_path=predictions,
        include_critical_failures=include_critical_failures,
        include_historical=include_historical,
        min_records=min_records,
        max_records=max_records,
        repair_strategy=repair_strategy,
    )
    skipped.update(collection_skipped)
    # Remember which run each prediction came from, keyed by object identity.
    # bucket_weighted_order returns the same objects, so the lookup survives
    # the reordering.
    prediction_by_identity = {id(prediction): source_run_id for source_run_id, prediction in eligible_with_source}
    ordered = bucket_weighted_order([prediction for _source_run_id, prediction in eligible_with_source])
    for prediction in ordered:
        # Once the cap is hit, keep looping only to count the dropped rows.
        if len(pairs) >= max_records:
            skipped["max_records_reached"] += 1
            continue
        pair = build_pair(
            prediction=prediction,
            asset_class=asset_class,
            role=role,
            run_id=run_id,
            source_run_id=prediction_by_identity.get(id(prediction), run_id),
            record_index=len(pairs),
            include_reasoning_chosen=include_reasoning_chosen,
            repair_strategy=repair_strategy,
        )
        if pair is None:
            skipped["missing_prompt_or_answers"] += 1
            continue
        # Non-admitted targets are skipped before schema validation, so they
        # are counted here rather than reported as schema errors.
        if pair["metadata"]["repair_target"]["admitted_to_training"] is not True:
            skipped["repair_target_not_admitted"] += 1
            continue
        errors = validate_pair(pair)
        if errors:
            schema_errors.extend(f"{pair.get('metadata', {}).get('source_prediction_id')}: {error}" for error in errors)
            continue
        pairs.append(pair)
        defect_counts.update(pair["metadata"]["defect_types"])
        bucket_selection_counts.update([pair["metadata"]["failure_bucket"]])
    output.parent.mkdir(parents=True, exist_ok=True)
    write_jsonl(output, pairs)
    output_sha = sha256_file(output)
    summary = {
        "repair_strategy": repair_strategy,
        "preference_pair_count": len(pairs),
        "min_records": min_records,
        "min_records_met": len(pairs) >= min_records,
        "include_historical": include_historical,
        "historical_pair_count": sum(1 for pair in pairs if pair["metadata"]["source_run_id"] != run_id),
        "pairwise_loss_pair_count": sum(1 for pair in pairs if pair["metadata"]["pairwise_loss"]),
        "critical_failure_pair_count": sum(1 for pair in pairs if pair["metadata"]["critical_failure"]),
        "gold_chosen_pair_count": sum(1 for pair in pairs if pair["metadata"].get("chosen_is_gold")),
        "baseline_capped_pair_count": sum(1 for pair in pairs if pair["metadata"].get("parity_capped")),
        # Every kept pair already passed the admission check in the loop, so
        # this equals preference_pair_count.
        "admitted_pair_count": sum(1 for pair in pairs if pair["metadata"]["repair_target"]["admitted_to_training"]),
        "selection_strategy": f"{repair_strategy}_bucket_weighted_round_robin",
        "bucket_selection_counts": dict(bucket_selection_counts),
        "defect_counts": dict(defect_counts),
        "skipped": dict(skipped),
        "schema_error_count": len(schema_errors),
        "max_records": max_records,
        "include_critical_failures": include_critical_failures,
        "include_reasoning_chosen": include_reasoning_chosen,
    }
    result = {
        "ok": len(pairs) >= min_records and not schema_errors,
        "schema_version": PREFERENCE_SCHEMA_VERSION,
        "run_id": run_id,
        "asset_class": asset_class,
        "role": role,
        "predictions_path": str(predictions),
        "paired_eval_report_path": str(report_path) if report_path.exists() else None,
        "paired_eval_improvement": report.get("improvement", {}),
        "output_path": str(output),
        "output_sha256": output_sha,
        "manifest_path": str(manifest_path),
        "markdown_path": str(markdown_path),
        "summary": summary,
        "schema_errors": schema_errors,
        "training_recommendation": {
            "objective": "preference_optimization",
            "preferred_algorithms": ["DPO", "ORPO", "KTO"],
            "repair_strategy": repair_strategy,
            "reason": f"paired evaluation produced explicit chosen/rejected failures; optimize the {repair_strategy} lane before paid SFT breakout retries",
        },
        "created_at": utc_now(),
    }
    write_json(manifest_path, result)
    write_markdown_summary(markdown_path, result)
    return result

Xet Storage Details

Size:
34.4 kB
·
Xet hash:
034a1b8abb8173eef197db0ffe4a616c9c9dfa0ae17dcf52d3a9e48afd800cae

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.