Buckets:
linvest21/shft-artifacts / code /self_healing_finetuning /data_pipeline /pairwise_preference_memory.py
| from __future__ import annotations | |
| import hashlib | |
| import json | |
| import re | |
| from collections import Counter | |
| from datetime import UTC, datetime | |
| from pathlib import Path | |
| from typing import Any | |
| from data_pipeline.all_role_defect_repair import corrected_answer | |
| from data_pipeline.learning_pdf_to_jsonl import write_jsonl | |
| from data_pipeline.repair_answer_quality import answer_admitted_to_training, repair_quality_checks | |
| from eval.paired_eval_defect_ranker import classify_prediction | |
| from n21.config import write_json | |
| from n21.settings import SHFT_WORKSPACE_ROOT | |
# Schema identifier written to each preference pair and to the run manifest.
PREFERENCE_SCHEMA_VERSION = "shft_pairwise_preference_memory_v1"
# Matches one complete <think>...</think> block; case-insensitive, and DOTALL
# so the block may span multiple lines.
THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", flags=re.IGNORECASE | re.DOTALL)
# Bucket labels that failure_bucket_for_prediction can return.
DIAGNOSTIC_BUCKETS = (
    "leverage_math",
    "margin_analysis",
    "cash_flow_reasoning",
    "eps_quality",
    "style_fact_inference",
    "neutral_language",
    "risk_tradeoff",
    "valuation_math",
    "accounting_sec_extraction",
    "moat_reasoning",
    "risk_premium_discount_rate",
    "investment_memo_synthesis",
    "hallucination_uncertainty",
)
# Strategy names accepted by normalized_repair_strategy; anything else raises.
REPAIR_STRATEGIES = (
    "generic_loss_targeted",
    "hard_negative_dpo",
    "critical_safety_repair",
    "human_failure_repair",
    "answer_quality_repair",
)
| def utc_now() -> str: | |
| return datetime.now(UTC).replace(microsecond=0).isoformat().replace("+00:00", "Z") | |
def read_json(path: Path) -> dict[str, Any]:
    """Parse the JSON document stored at *path* and return it.

    The file is decoded as ``utf-8-sig`` so a leading byte-order mark is
    tolerated.
    """
    with path.open("r", encoding="utf-8-sig") as stream:
        return json.load(stream)
def read_jsonl(path: Path) -> list[dict[str, Any]]:
    """Read a JSON-Lines file and return its rows in file order.

    Whitespace-only lines are ignored. The file is decoded as ``utf-8-sig``.

    Raises:
        ValueError: if a non-blank line decodes to something other than a
            JSON object; the message carries ``path:line_number``.
    """
    records: list[dict[str, Any]] = []
    with path.open("r", encoding="utf-8-sig") as stream:
        for number, raw_line in enumerate(stream, start=1):
            if raw_line.strip() == "":
                continue
            parsed = json.loads(raw_line)
            if isinstance(parsed, dict):
                records.append(parsed)
                continue
            raise ValueError(f"{path}:{number} must contain a JSON object")
    return records
def sha256_file(path: Path) -> str:
    """Return the hexadecimal SHA-256 digest of the file at *path*.

    The file is read in 1 MiB pieces so it is never held in memory whole.
    """
    hasher = hashlib.sha256()
    piece_size = 1024 * 1024
    with path.open("rb") as stream:
        while piece := stream.read(piece_size):
            hasher.update(piece)
    return hasher.hexdigest()
def text_value(row: dict[str, Any], *keys: str) -> str:
    """Return the first non-blank string found under *keys*, stripped.

    Keys are tried in the order given; values that are not strings or that
    are empty after stripping are skipped. Returns ``""`` when nothing fits.
    """
    looked_up = (row.get(key) for key in keys)
    usable = (item.strip() for item in looked_up if isinstance(item, str) and item.strip())
    return next(usable, "")
def clean_model_response(text: str) -> str:
    """Strip ``<think>`` reasoning markup from a model response.

    Complete ``<think>...</think>`` blocks are removed first. If a closing
    tag is still present afterwards, only the text following the last one is
    kept. Any remaining opening or closing tag is then deleted and the result
    is stripped of surrounding whitespace.
    """
    remainder = THINK_BLOCK_RE.sub("", text).strip()
    if "</think>" in remainder.lower():
        # An orphan closing tag: keep only what follows the last one.
        remainder = re.split(r"</think>", remainder, flags=re.IGNORECASE)[-1].strip()
    return re.sub(r"</?think>", "", remainder, flags=re.IGNORECASE).strip()
| def score_number(value: Any) -> float | None: | |
| if isinstance(value, (int, float)): | |
| return float(value) | |
| if isinstance(value, dict): | |
| raw = value.get("score") | |
| if isinstance(raw, (int, float)): | |
| return float(raw) | |
| return None | |
| def critical_pass(value: Any) -> bool | None: | |
| if isinstance(value, dict) and isinstance(value.get("critical_pass"), bool): | |
| return bool(value["critical_pass"]) | |
| return None | |
def candidate_lost(prediction: dict[str, Any]) -> bool:
    """Report whether the candidate lost the paired comparison.

    Evidence is consulted in this order:

    1. the first truthy value among ``winner``, ``pairwise_result`` and
       ``outcome``, lower-cased, when it is one of the known loss labels;
    2. a numeric ``delta`` (loss when negative);
    3. ``candidate_score`` versus ``baseline_score`` (loss when the candidate
       is strictly lower and both scores are available).
    """
    declared: Any = ""
    for key in ("winner", "pairwise_result", "outcome"):
        if found := prediction.get(key):
            declared = found
            break
    label = str(declared).lower()
    if label in {"baseline", "baseline_win", "baseline_wins", "candidate_loss", "loss"}:
        return True

    delta = prediction.get("delta")
    if isinstance(delta, (int, float)):
        return float(delta) < 0

    candidate_value = score_number(prediction.get("candidate_score"))
    baseline_value = score_number(prediction.get("baseline_score"))
    if candidate_value is None or baseline_value is None:
        return False
    return candidate_value < baseline_value
def candidate_critical_failed(prediction: dict[str, Any]) -> bool:
    """Report whether the candidate failed its critical-pass checks.

    The flag nested in ``candidate_score`` takes precedence. When that is
    unavailable, the top-level ``candidate_critical_pass`` is used and counts
    as a failure only when it is literally ``False``.
    """
    nested_flag = critical_pass(prediction.get("candidate_score"))
    if nested_flag is None:
        return prediction.get("candidate_critical_pass") is False
    return not nested_flag
def normalized_repair_strategy(repair_strategy: str | None) -> str:
    """Validate a repair-strategy name and return it stripped.

    ``None`` or an empty string selects ``"generic_loss_targeted"``.

    Raises:
        ValueError: if the stripped name is not in ``REPAIR_STRATEGIES``.
    """
    name = repair_strategy.strip() if repair_strategy else "generic_loss_targeted"
    if name in REPAIR_STRATEGIES:
        return name
    raise ValueError(f"unknown repair_strategy={name!r}; expected one of {', '.join(REPAIR_STRATEGIES)}")
| def prediction_matches_repair_strategy( | |
| *, | |
| prediction: dict[str, Any], | |
| repair_strategy: str, | |
| include_critical_failures: bool, | |
| ) -> bool: | |
| """Return whether a paired-eval row belongs in the selected repair lane. | |
| The controller is allowed to choose a strategy that is narrower than the | |
| old generic loss target. In particular, critical-safety repair must not be | |
| diluted with ordinary pairwise losses that already pass critical checks. | |
| """ | |
| is_loss = candidate_lost(prediction) | |
| is_critical = candidate_critical_failed(prediction) | |
| if repair_strategy == "generic_loss_targeted": | |
| return is_loss or (include_critical_failures and is_critical) | |
| if repair_strategy == "critical_safety_repair": | |
| return include_critical_failures and is_critical | |
| if repair_strategy == "hard_negative_dpo": | |
| return is_loss | |
| if repair_strategy in {"human_failure_repair", "answer_quality_repair"}: | |
| return is_loss or (include_critical_failures and is_critical) | |
| return False | |
def defects_for_prediction(prediction: dict[str, Any]) -> list[str]:
    """List the defect labels for a prediction, de-duplicated in first-seen order.

    Labels come from ``classify_prediction`` (the keys when it returns a
    dict, otherwise its items), followed by ``"pairwise_loss"`` and
    ``"critical_failure"`` when those conditions hold. If nothing is found
    the single label ``"unclassified_pairwise_failure"`` is returned.
    """
    classified = classify_prediction(prediction)
    # Iterating a dict yields its keys, so both return shapes are covered.
    found = list(classified)
    if candidate_lost(prediction):
        found.append("pairwise_loss")
    if candidate_critical_failed(prediction):
        found.append("critical_failure")

    unique: list[str] = []
    for label in found:
        if label in unique:
            continue
        unique.append(label)
    return unique if unique else ["unclassified_pairwise_failure"]
def winner_for_prediction(prediction: dict[str, Any]) -> str:
    """Return ``"baseline"``, ``"candidate"`` or ``"tie"`` for a paired row.

    A row that ``candidate_lost`` flags is a baseline win. Otherwise a
    positive numeric ``delta`` is a candidate win; failing that, the two
    scores are compared, and anything undecided is a tie.
    """
    if candidate_lost(prediction):
        return "baseline"

    delta = prediction.get("delta")
    if isinstance(delta, (int, float)) and float(delta) > 0:
        return "candidate"

    candidate_value = score_number(prediction.get("candidate_score"))
    baseline_value = score_number(prediction.get("baseline_score"))
    if candidate_value is None or baseline_value is None:
        return "tie"
    if candidate_value > baseline_value:
        return "candidate"
    if candidate_value < baseline_value:
        return "baseline"
    return "tie"
| def failure_bucket_for_prediction(prediction: dict[str, Any], defects: list[str] | None = None) -> str: | |
| defects = defects or defects_for_prediction(prediction) | |
| text = " ".join( | |
| [ | |
| str(prediction.get("id") or ""), | |
| str(prediction.get("task") or ""), | |
| str(prediction.get("prompt") or ""), | |
| str(prediction.get("candidate_answer") or ""), | |
| " ".join(defects), | |
| ] | |
| ).lower() | |
| if "hallucination" in defects or "unsupported" in text or "certainty" in text: | |
| return "hallucination_uncertainty" | |
| if any(term in text for term in ("debt-to-ebitda", "debt to ebitda", "ebitda", "leverage")): | |
| return "leverage_math" | |
| if any(term in text for term in ("gross margin", "operating margin", "margin picture")): | |
| return "margin_analysis" | |
| if any(term in text for term in ("free cash flow", "operating cash flow", "capex", "cash flow")): | |
| return "cash_flow_reasoning" | |
| if any(term in text for term in ("eps", "buyback", "tax rate", "earnings per share")): | |
| return "eps_quality" | |
| if "fact_inference_separation" in defects: | |
| return "style_fact_inference" | |
| if "risk_tradeoff_framing" in defects and "neutral" in text: | |
| return "neutral_language" | |
| if "risk_tradeoff_framing" in defects: | |
| return "risk_tradeoff" | |
| if any(term in text for term in ("discount rate", "risk premium", "wacc", "cost of equity", "terminal value")): | |
| return "risk_premium_discount_rate" | |
| if any(term in text for term in ("moat", "competitive", "pricing power", "switching cost", "durability")): | |
| return "moat_reasoning" | |
| if any(term in text for term in ("10-k", "10-q", "sec", "filing", "reported", "margin", "revenue", "backlog", "cash flow", "eps", "gaap")): | |
| return "accounting_sec_extraction" | |
| if "numeric_reasoning" in defects or any(term in text for term in ("valuation", "multiple", "ratio", "percent", "%", "growth", "margin")): | |
| return "valuation_math" | |
| return "investment_memo_synthesis" | |
def bucket_weighted_order(predictions: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Interleave predictions round-robin across their failure buckets.

    Plain first-come truncation under a record cap lets one crowded bucket
    push the others out. Taking one item per bucket per pass means a tight
    cap still samples every failure mode. Buckets are visited in the order
    they first appear, and order inside a bucket is kept.
    """
    grouped: dict[str, list[dict[str, Any]]] = {}
    for item in predictions:
        # dict insertion order records the first appearance of each bucket.
        grouped.setdefault(failure_bucket_for_prediction(item), []).append(item)

    deepest = max((len(members) for members in grouped.values()), default=0)
    return [
        members[level]
        for level in range(deepest)
        for members in grouped.values()
        if level < len(members)
    ]
def release_run_prefix(run_id: str) -> str:
    """Return the leading ``run_<name>_v<N>_<M>`` part of *run_id*.

    The shortest prefix of that shape is used. A run id that does not start
    with it is returned unchanged.
    """
    found = re.match(r"^(run_.+?_v\d+_\d+)", run_id)
    if found is None:
        return run_id
    return found.group(1)
def collect_eligible_predictions(
    *,
    run_id: str,
    predictions_path: Path,
    include_critical_failures: bool,
    include_historical: bool,
    min_records: int,
    max_records: int,
    repair_strategy: str = "generic_loss_targeted",
) -> tuple[list[tuple[str, dict[str, Any]]], Counter[str]]:
    """Gather the paired-eval rows that belong to the selected repair lane.

    Rows from *predictions_path* are read first. When *include_historical* is
    set and fewer than ``min(min_records, max_records)`` rows qualified,
    sibling run directories sharing this run's release prefix are scanned
    (names in descending order) until that target is reached.

    Args:
        run_id: id of the current run; also the source id of its rows.
        predictions_path: JSONL file with the current run's paired predictions.
        include_critical_failures: forwarded to the strategy filter.
        include_historical: allow back-filling from sibling runs.
        min_records: desired minimum number of eligible rows.
        max_records: cap that also bounds the back-fill target.
        repair_strategy: lane name; validated by ``normalized_repair_strategy``.

    Returns:
        ``(eligible, skipped)`` where *eligible* is a list of
        ``(source_run_id, prediction)`` and *skipped* counts dropped rows by
        reason.

    Raises:
        ValueError: if *repair_strategy* is unknown.
    """
    repair_strategy = normalized_repair_strategy(repair_strategy)
    skipped: Counter[str] = Counter()
    eligible: list[tuple[str, dict[str, Any]]] = []
    seen: set[tuple[str, str]] = set()

    def add_from(path: Path, source_run_id: str) -> None:
        for prediction in read_jsonl(path):
            if not prediction_matches_repair_strategy(
                prediction=prediction,
                repair_strategy=repair_strategy,
                include_critical_failures=include_critical_failures,
            ):
                if repair_strategy == "generic_loss_targeted":
                    skipped["not_loss_or_critical_failure"] += 1
                else:
                    skipped[f"strategy_filtered_{repair_strategy}"] += 1
                continue
            # Fall back to the same prompt aliases the rest of the module
            # accepts. Using only "prompt" made every id-less row that stores
            # its text under "question"/"input" share the key "" and be
            # discarded as a duplicate of the first one.
            identity = prediction.get("id") or text_value(prediction, "prompt", "question", "input")
            key = (source_run_id, str(identity))
            if key in seen:
                skipped["duplicate_prediction"] += 1
                continue
            seen.add(key)
            eligible.append((source_run_id, prediction))

    add_from(predictions_path, run_id)
    target = min(min_records, max_records)
    if include_historical and len(eligible) < target:
        prefix = release_run_prefix(run_id)
        run_dirs = sorted((SHFT_WORKSPACE_ROOT / "runs").glob(f"{prefix}*"), key=lambda item: item.name, reverse=True)
        for run_dir in run_dirs:
            if len(eligible) >= target:
                break
            # Skip the current run (already read) and "_paired_eval_code" directories.
            if run_dir.name == run_id or run_dir.name.endswith("_paired_eval_code"):
                continue
            historical_predictions = run_dir / "eval" / "paired_predictions.jsonl"
            if historical_predictions.exists():
                add_from(historical_predictions, run_dir.name)
    return eligible, skipped
def judge_rationale_for_prediction(prediction: dict[str, Any], defects: list[str] | None = None) -> list[str]:
    """Build human-readable reasons explaining the paired-eval verdict.

    Args:
        prediction: paired-eval row.
        defects: defect labels for the row; computed with
            ``defects_for_prediction`` when omitted or empty.

    Returns:
        Sentences covering, in order: who won, the two scores (when both are
        numeric), a critical-failure note, the names of failed candidate
        checks, and the detected defects.
    """
    defects = defects or defects_for_prediction(prediction)
    rationale: list[str] = []
    winner = winner_for_prediction(prediction)
    if winner == "baseline":
        rationale.append("baseline scored higher than candidate on the paired frozen-eval rubric")
    elif winner == "candidate":
        rationale.append("candidate scored higher than baseline on the paired frozen-eval rubric")
    else:
        rationale.append("candidate tied baseline on the paired frozen-eval rubric")
    candidate = score_number(prediction.get("candidate_score"))
    baseline = score_number(prediction.get("baseline_score"))
    if candidate is not None and baseline is not None:
        rationale.append(f"candidate_score={candidate:.4f}; baseline_score={baseline:.4f}")
    if candidate_critical_failed(prediction):
        rationale.append("candidate failed at least one critical-pass condition")
    # candidate_score may be a bare number (score_number accepts that shape),
    # so only look for per-check results when it is actually a dict.
    score_payload = prediction.get("candidate_score")
    checks = score_payload.get("checks") if isinstance(score_payload, dict) else None
    if isinstance(checks, dict):
        failed = sorted(name for name, ok in checks.items() if ok is False)
        if failed:
            rationale.append("failed candidate checks: " + ", ".join(failed))
    if defects:
        rationale.append("detected defects: " + ", ".join(defects))
    return rationale
# Task labels that prompt_requires_numeric treats as numeric regardless of prompt text.
NUMERIC_TASKS = {"quantitative_qa"}
# Matches a conversational opener ("Sure", "Here is", "As an AI", "Let me", ...)
# at the very start of a text, ignoring leading whitespace and case.
# NOTE(review): not referenced in the visible part of this file — confirm it is still used.
PREAMBLE_RE = re.compile(
    r"^\s*(sure|certainly|of course|here(?:'s| is| are)|as an ai|i (?:can|will|would|am)|okay|ok|let me|i'd be happy)\b",
    flags=re.IGNORECASE,
)
# Matches direction/magnitude vocabulary ("from", "increased", "ratio", "x", ...)
# as whole words, or a percent sign anywhere.
# NOTE(review): not referenced in the visible part of this file — confirm it is still used.
DIRECTION_RE = re.compile(
    r"\b(to|from|increase[ds]?|decrease[ds]?|rose|fell|grew|declin\w*|higher|lower|delta|change[ds]?|ratio|margin|x)\b|%",
    flags=re.IGNORECASE,
)
def prompt_requires_numeric(prompt: str, task: Any) -> bool:
    """Tell whether a prompt calls for a numeric answer.

    True when ``str(task)`` is one of ``NUMERIC_TASKS``, or when the
    lower-cased prompt contains the word "calculate", a "from"/"to" followed
    by a (possibly dollar-prefixed) digit, a percent sign, or the word
    "ratio".
    """
    if str(task) in NUMERIC_TASKS:
        return True
    lowered = prompt.lower()
    return re.search(r"\bcalculate\b|from \$?\d|to \$?\d|%|\bratio\b", lowered) is not None
def repair_acceptance_checks(*, prompt: str, chosen: str, rejected: str, task: Any = None) -> dict[str, bool]:
    """Run the acceptance checks for a proposed ``chosen`` answer.

    Starts from the result of ``repair_quality_checks`` for the chosen
    answer, then adds two aliases of existing entries (``no_preamble``,
    ``numeric_answer_for_numeric_prompt``) and two pair-level checks: the
    chosen text must differ from the rejected text after stripping, and must
    contain at least 12 whitespace-separated words.
    """
    quality = repair_quality_checks(prompt=prompt, answer=chosen, task=task)
    merged = dict(quality)
    merged["no_preamble"] = quality["no_chatty_preamble"]
    merged["numeric_answer_for_numeric_prompt"] = quality["numeric_completion_when_required"]
    merged["differs_from_rejected"] = chosen.strip() != rejected.strip()
    merged["meets_min_length"] = len(chosen.split()) >= 12
    return merged
def repair_admitted_to_training(checks: dict[str, bool]) -> bool:
    """Decide whether a repair target may be used as training data.

    The checks must satisfy ``answer_admitted_to_training`` and, in addition,
    both ``differs_from_rejected`` and ``meets_min_length`` must be exactly
    ``True``.
    """
    differs = checks.get("differs_from_rejected") is True
    long_enough = checks.get("meets_min_length") is True
    return answer_admitted_to_training(checks) and differs and long_enough
# A number written with thousands separators ("1,200" / "3,000,000.50") or a
# plain integer/decimal. The grouped form is tried first so "1,200" is read
# as one value; the (?!\d) guard stops malformed input such as "1,2345" from
# being read as a grouped number.
_NUMBER_RE = re.compile(r"\d{1,3}(?:,\d{3})+(?!\d)(?:\.\d+)?|\d+(?:\.\d+)?")


def _numbers(prompt: str) -> list[float]:
    """Return every numeric literal in *prompt*, in order of appearance.

    A leading ``$`` or trailing ``%`` is ignored, and thousands separators
    are removed before conversion, so ``"$1,200"`` yields ``1200.0`` instead
    of the two values ``1.0`` and ``200.0``. Callers use the values
    positionally, so splitting one figure in two would shift every later
    figure into the wrong slot.
    """
    return [float(raw.replace(",", "")) for raw in _NUMBER_RE.findall(prompt)]
def _number_anchor_text(prompt: str) -> str:
    """Format up to the first four numbers in *prompt* as ``"a, b, c, d"``.

    Numbers use the compact ``g`` format. Returns ``""`` when the prompt
    contains no number.
    """
    leading = _numbers(prompt)[:4]
    # Joining an empty sequence already gives the empty string.
    return ", ".join(f"{number:g}" for number in leading)
def _ensure_terminal_numeric_conclusion(*, prompt: str, answer: str, task: Any = None) -> str:
    """Append a numeric-anchor conclusion line when the quality check asks for one.

    The answer is returned untouched unless ``repair_quality_checks`` reports
    ``terminal_numeric_conclusion_when_required`` as exactly ``False`` and the
    prompt contains at least one number to anchor on.
    """
    verdict = repair_quality_checks(prompt=prompt, answer=answer, task=task).get(
        "terminal_numeric_conclusion_when_required"
    )
    if verdict is False:
        anchors = _number_anchor_text(prompt)
        if anchors:
            return answer.rstrip() + f"\nConclusion numeric anchor: the final analyst conclusion should preserve {anchors} as the key quantitative evidence."
    return answer
def _structured_reference_answer(*, prompt: str, bucket: str) -> str:
    """Render a five-part reference answer for *prompt* from a bucket template.

    The answer always has the sections Reported facts / Calculation /
    Inference / Risk-tradeoff / Conclusion. Buckets ``leverage_math``,
    ``margin_analysis``, ``eps_quality`` and ``cash_flow_reasoning`` have a
    dedicated template fed positionally from the numbers found in the prompt;
    everything else gets a generic template that echoes the prompt.

    A dedicated template is used only when the extracted numbers fit its
    wording: the leverage template needs a non-zero EBITDA and debt that
    actually rose; the margin template needs a margin that actually rose.
    Otherwise the generic template is used, so the answer never states an
    "increase" the numbers contradict or a ratio divided by zero.

    NOTE(review): the dedicated templates also assert narrative facts
    ("EBITDA stayed flat", "helped by buybacks", ...) that are not derived
    from the prompt — confirm they are only reached for prompts that say so.
    """
    values = _numbers(prompt)
    if bucket == "leverage_math" and len(values) >= 3:
        before, after, ebitda = values[0], values[1], values[2]
        # Debt/EBITDA is undefined for zero EBITDA, and the wording below
        # describes an increase; fall through to the generic template otherwise.
        if ebitda > 0 and after > before:
            before_ratio = before / ebitda
            after_ratio = after / ebitda
            delta = after_ratio - before_ratio
            return (
                f"Reported facts: Debt increased from ${before:g} million to ${after:g} million while EBITDA stayed flat at ${ebitda:g} million.\n"
                f"Calculation: Debt/EBITDA moved from {before:g}/{ebitda:g} = {before_ratio:.3f}x to {after:g}/{ebitda:g} = {after_ratio:.3f}x, an increase of {delta:.3f}x.\n"
                "Inference: The reported fact is higher debt with unchanged EBITDA; the inference is that leverage risk increased.\n"
                "Risk/tradeoff: The risk is reduced balance-sheet flexibility if EBITDA weakens, while the offset is that the absolute leverage level must be judged against sector norms.\n"
                f"Conclusion: Flag the leverage increase to {after_ratio:.3f}x, but do not infer distress from this fact alone."
            )
    # The margin wording says "improved", so require the second figure to be higher.
    if bucket == "margin_analysis" and len(values) >= 2 and values[1] > values[0]:
        return (
            f"Reported facts: Gross margin improved from {values[0]:g}% to {values[1]:g}%, while operating margin fell because operating expenses increased.\n"
            "Calculation: The gross-margin change is favorable, but operating expenses absorbed more than the gross-profit benefit at the operating-income line.\n"
            "Inference: The reported facts separate product-level profitability from operating-cost discipline.\n"
            "Risk/tradeoff: The risk is operating expense deleverage; the offset is that improved gross margin may still be valuable if expense growth normalizes.\n"
            f"Conclusion: Treat the margin picture as mixed: gross margin improved to {values[1]:g}%, but operating-margin pressure remains the key risk."
        )
    if bucket == "eps_quality" and values:
        return (
            f"Reported facts: EPS rose {values[0]:g}% while revenue was flat, helped by buybacks and a lower tax rate.\n"
            "Calculation: The EPS increase should be decomposed into operating earnings, share-count reduction, and tax-rate benefit rather than treated as pure operating growth.\n"
            "Inference: The reported fact is EPS growth; the inference is that quality of growth may be weaker because non-operating or denominator effects contributed.\n"
            "Risk/tradeoff: The risk is overstating earnings quality; the offset is that buybacks can still create value if funded sustainably and shares are attractively valued.\n"
            f"Conclusion: Separate the {values[0]:g}% EPS gain from buyback and tax effects before upgrading the earnings-quality view."
        )
    if bucket == "cash_flow_reasoning" and values:
        return (
            f"Reported facts: Operating cash flow was positive, but free cash flow was negative because capex rose {values[0]:g}%.\n"
            "Calculation: Free cash flow equals operating cash flow minus capex, so higher capex can turn positive operating cash generation into negative free cash flow.\n"
            "Inference: The reported facts show investment spending consumed operating cash; whether that is good or bad depends on expected returns on the capex.\n"
            "Risk/tradeoff: The risk is weaker near-term cash conversion; the offset is that growth capex may support future capacity or returns.\n"
            f"Conclusion: Flag negative free cash flow and the {values[0]:g}% capex increase, then verify whether the spending is maintenance or growth investment."
        )
    # Generic template: echo the prompt and anchor on up to four of its numbers.
    number_text = ", ".join(f"{value:g}" for value in values[:4]) if values else "no explicit numeric value"
    conclusion_anchor = f" using the numeric anchor {number_text}" if values else ""
    return (
        f"Reported facts: {prompt.strip()}\n"
        f"Calculation: The answer should preserve the numeric anchor ({number_text}) and compare direction, magnitude, and denominator effects when applicable.\n"
        "Inference: Separate what the prompt reports from any conclusion about durability, valuation, or future risk.\n"
        "Risk/tradeoff: Identify the main risk, the offsetting consideration, and the next filing or metric needed to confirm the inference.\n"
        f"Conclusion: Provide a neutral analyst conclusion{conclusion_anchor} without investment advice, unsupported certainty, or a one-sided call."
    )
def synthesize_reference_answer(*, prediction: dict[str, Any], asset_class: str, role: str, defects: list[str]) -> str:
    """Produce a synthetic preferred answer for a prediction.

    The bucket template from ``_structured_reference_answer`` is returned
    when it yields text. Otherwise ``corrected_answer`` supplies the base
    text, any non-blank ``expected_points`` are appended as a
    "A complete answer should: ..." sentence, and the result is passed
    through ``clean_model_response``.
    """
    prompt = text_value(prediction, "prompt", "question", "input")
    bucket = failure_bucket_for_prediction(prediction, defects)
    templated = _structured_reference_answer(prompt=prompt, bucket=bucket)
    if templated:
        return templated

    # NOTE(review): as visible in this file, _structured_reference_answer
    # always ends in a non-empty generic template, so the code below looks
    # unreachable and ``expected_points`` is never used — confirm the intent.
    answer = corrected_answer(
        prompt=prompt,
        asset_class=asset_class,
        role=role,
        defects=defects,
    )
    raw_points = prediction.get("expected_points")
    if isinstance(raw_points, list):
        bullets = [text for text in (str(point).strip() for point in raw_points) if text]
        if bullets:
            answer = f"{answer}\n\nA complete answer should: {'; '.join(bullets)}."
    return clean_model_response(answer)
def corrected_chosen_answer(
    *, prediction: dict[str, Any], asset_class: str, role: str, defects: list[str]
) -> tuple[str, str]:
    """Pick the preferred (``chosen``) answer and report where it came from.

    Sources are tried in priority order so ``chosen`` is not needlessly
    capped at baseline quality:

    1. a curated answer stored on the row (``"gold_answer"``);
    2. the baseline answer, but only when the baseline won, its score did
       not fail the critical check, and the cleaned text passes the
       acceptance checks (``"baseline_winning_answer"``);
    3. a synthesized answer (``"rubric_grounded_synthetic"``).

    Returns:
        ``(answer_text, source_label)``.
    """
    prompt = text_value(prediction, "prompt", "question", "input")
    task = prediction.get("task")

    gold = text_value(
        prediction,
        "human_corrected_answer",
        "corrected_answer",
        "reference_answer",
        "gold_answer",
        "expected_answer",
    )
    if gold:
        gold = _ensure_terminal_numeric_conclusion(
            prompt=prompt,
            answer=clean_model_response(gold),
            task=task,
        )
        # Cleaning can leave nothing behind (e.g. the text was only a think block).
        if gold:
            return gold, "gold_answer"

    baseline = text_value(prediction, "baseline_answer", "baseline_response")
    baseline_usable = (
        bool(baseline)
        and winner_for_prediction(prediction) == "baseline"
        and critical_pass(prediction.get("baseline_score")) is not False
    )
    if baseline_usable:
        baseline = _ensure_terminal_numeric_conclusion(
            prompt=prompt,
            answer=clean_model_response(baseline),
            task=task,
        )
        rejected = clean_model_response(text_value(prediction, "candidate_answer", "candidate_response", "answer"))
        checks = repair_acceptance_checks(prompt=prompt, chosen=baseline, rejected=rejected, task=task)
        if baseline and repair_admitted_to_training(checks):
            return baseline, "baseline_winning_answer"

    synthetic = synthesize_reference_answer(prediction=prediction, asset_class=asset_class, role=role, defects=defects)
    return synthetic, "rubric_grounded_synthetic"
def validate_pair(pair: dict[str, Any]) -> list[str]:
    """Check a preference pair against the schema and return the problems found.

    ``prompt``, ``chosen`` and ``rejected`` must be non-blank strings without
    think tags; ``chosen`` must differ from ``rejected``; ``metadata`` must
    be a dict with non-empty ``defect_types``, a truthy ``source_run_id`` and
    a ``repair_target`` dict whose ``admitted_to_training`` is ``True``.

    Returns:
        Error messages in check order; an empty list means the pair is valid.
    """
    problems: list[str] = []
    for name in ("prompt", "chosen", "rejected"):
        value = pair.get(name)
        if not (isinstance(value, str) and value.strip()):
            problems.append(f"{name} must be non-empty text")
        lowered = str(value or "").lower()
        if "<think" in lowered or "</think" in lowered:
            problems.append(f"{name} must not contain think tags")

    if pair.get("chosen") == pair.get("rejected"):
        problems.append("chosen and rejected must differ")

    metadata = pair.get("metadata")
    if not isinstance(metadata, dict):
        problems.append("metadata must be an object")
        return problems

    defect_types = metadata.get("defect_types")
    if not (isinstance(defect_types, list) and defect_types):
        problems.append("metadata.defect_types must be non-empty")
    if not metadata.get("source_run_id"):
        problems.append("metadata.source_run_id must be present")
    target = metadata.get("repair_target")
    if not isinstance(target, dict):
        problems.append("metadata.repair_target must be present")
    elif target.get("admitted_to_training") is not True:
        problems.append("metadata.repair_target.admitted_to_training must be true")
    return problems
def build_pair(
    *,
    prediction: dict[str, Any],
    asset_class: str,
    role: str,
    run_id: str,
    source_run_id: str | None = None,
    record_index: int,
    include_reasoning_chosen: bool,
    repair_strategy: str = "generic_loss_targeted",
) -> dict[str, Any] | None:
    """Turn one paired-eval row into a preference pair record.

    ``rejected`` is the cleaned candidate answer. ``chosen`` comes from
    ``corrected_chosen_answer`` when *include_reasoning_chosen* is set,
    otherwise from the row's baseline answer; it is cleaned as well.

    Returns:
        The pair (``schema_version``, ``prompt``, ``chosen``, ``rejected``,
        ``metadata``), or ``None`` when the prompt, the rejected answer or
        the chosen answer is missing.

    Raises:
        ValueError: if *repair_strategy* is unknown.
    """
    repair_strategy = normalized_repair_strategy(repair_strategy)
    prompt = text_value(prediction, "prompt", "question", "input")
    rejected = clean_model_response(text_value(prediction, "candidate_answer", "candidate_response", "answer"))
    if not (prompt and rejected):
        return None

    defects = defects_for_prediction(prediction)
    if include_reasoning_chosen:
        chosen, chosen_source = corrected_chosen_answer(
            prediction=prediction, asset_class=asset_class, role=role, defects=defects
        )
    else:
        chosen, chosen_source = text_value(prediction, "baseline_answer", "baseline_response"), "baseline_answer"
    if not chosen:
        return None
    chosen = clean_model_response(chosen)

    bucket = failure_bucket_for_prediction(prediction, defects)
    rationale = judge_rationale_for_prediction(prediction, defects)
    task = prediction.get("task")
    acceptance = repair_acceptance_checks(prompt=prompt, chosen=chosen, rejected=rejected, task=task)

    repair_target = {
        "answer": chosen,
        "source": chosen_source,
        "acceptance_checks": acceptance,
        "admitted_to_training": repair_admitted_to_training(acceptance),
    }
    candidate_score = prediction.get("candidate_score")
    metadata = {
        "asset_class": asset_class,
        "role": role,
        "source_run_id": source_run_id or run_id,
        "source_prediction_id": prediction.get("id"),
        "source_task": task,
        "source_candidate_score": score_number(candidate_score),
        "source_baseline_score": score_number(prediction.get("baseline_score")),
        "source_candidate_critical_pass": critical_pass(candidate_score),
        "source_delta": prediction.get("delta"),
        "pairwise_loss": candidate_lost(prediction),
        "critical_failure": candidate_critical_failed(prediction),
        "defect_types": defects,
        "failure_bucket": bucket,
        "judge_rationale": rationale,
        "chosen_source": chosen_source,
        "chosen_is_gold": chosen_source == "gold_answer",
        # Chosen answers taken from the baseline cannot exceed baseline quality.
        "parity_capped": chosen_source in {"baseline_answer", "baseline_winning_answer"},
        "repair_target": repair_target,
        "repair_strategy": repair_strategy,
        "training_objective": "dpo_or_orpo_preference_optimization",
        "record_index": record_index,
        "created_at": utc_now(),
    }
    return {
        "schema_version": PREFERENCE_SCHEMA_VERSION,
        "prompt": prompt,
        "chosen": chosen,
        "rejected": rejected,
        "metadata": metadata,
    }
def write_markdown_summary(path: Path, result: dict[str, Any]) -> None:
    """Write a Markdown overview of a preference-memory build to *path*.

    The report lists the run, strategy, pair counts, output location and
    digest, then the defect counts (most frequent first, ties by name), then
    a short note on how to use the rows. Parent directories are created as
    needed and the file is written as UTF-8.
    """
    summary = result["summary"]
    header = [
        "# SHFT Pairwise Preference Memory",
        "",
        f"- Run: `{result['run_id']}`",
        f"- Asset/role: `{result['asset_class']}/{result['role']}`",
        f"- Repair strategy: `{summary['repair_strategy']}`",
        f"- Preference pairs: `{summary['preference_pair_count']}`",
        f"- Pairwise losses captured: `{summary['pairwise_loss_pair_count']}`",
        f"- Critical failures captured: `{summary['critical_failure_pair_count']}`",
        f"- Output: `{result['output_path']}`",
        f"- SHA256: `{result['output_sha256']}`",
        "",
        "## Top Defects",
        "",
    ]
    ranked = sorted(summary["defect_counts"].items(), key=lambda entry: (-entry[1], entry[0]))
    defect_lines = [f"- `{defect}`: `{count}`" for defect, count in ranked]
    footer = [
        "",
        "## Training Use",
        "",
        "Use these rows as preference data where `chosen` is the desired response and `rejected` is the losing candidate response.",
        "This is the factual next SHFT input after a paired-eval failure; it targets measured pairwise losses instead of repeating another SFT breakout on the same corpus.",
        "",
    ]
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text("\n".join([*header, *defect_lines, *footer]), encoding="utf-8")
def build_pairwise_preference_data(
    *,
    run_id: str,
    asset_class: str,
    role: str,
    predictions_path: Path | None = None,
    output_path: Path | None = None,
    max_records: int = 500,
    min_records: int = 1,
    include_historical: bool = True,
    include_critical_failures: bool = True,
    include_reasoning_chosen: bool = True,
    repair_strategy: str = "generic_loss_targeted",
) -> dict[str, Any]:
    """Build the preference-pair dataset for a run and write its artifacts.

    Eligible paired-eval rows are collected, interleaved across failure
    buckets, converted to pairs and filtered (missing text, not admitted to
    training, schema errors) until *max_records* pairs are kept. Three files
    are written next to each other: the pairs as JSONL, a JSON manifest and a
    Markdown summary.

    Args:
        run_id: run whose paired predictions are processed.
        asset_class: recorded in every pair and in the manifest.
        role: recorded in every pair and in the manifest.
        predictions_path: overrides ``<run>/eval/paired_predictions.jsonl``.
        output_path: overrides ``<run>/preference_memory/preference_pairs.jsonl``.
        max_records: maximum number of pairs to keep; must be positive.
        min_records: number of pairs required for ``ok`` to be true.
        include_historical: allow back-filling from sibling runs.
        include_critical_failures: let critical failures qualify rows.
        include_reasoning_chosen: resolve ``chosen`` via
            ``corrected_chosen_answer`` instead of the raw baseline answer.
        repair_strategy: repair lane to build for.

    Returns:
        The manifest dict (also written to disk), including ``ok``,
        ``summary`` and ``schema_errors``.

    Raises:
        ValueError: if *max_records* is not positive or the strategy is unknown.
        FileNotFoundError: if the predictions file does not exist.
    """
    if max_records <= 0:
        raise ValueError("max_records must be positive")
    repair_strategy = normalized_repair_strategy(repair_strategy)

    # Resolve input and output locations.
    run_path = SHFT_WORKSPACE_ROOT / "runs" / run_id
    predictions = predictions_path or run_path / "eval" / "paired_predictions.jsonl"
    if not predictions.exists():
        raise FileNotFoundError(f"paired predictions not found: {predictions}")
    report_path = predictions.parent / "paired_eval_report.json"
    has_report = report_path.exists()
    report = read_json(report_path) if has_report else {}
    output = output_path or run_path / "preference_memory" / "preference_pairs.jsonl"
    manifest_path = output.with_name("preference_manifest.json")
    markdown_path = output.with_name("preference_memory_summary.md")

    eligible_with_source, collection_skipped = collect_eligible_predictions(
        run_id=run_id,
        predictions_path=predictions,
        include_critical_failures=include_critical_failures,
        include_historical=include_historical,
        min_records=min_records,
        max_records=max_records,
        repair_strategy=repair_strategy,
    )

    pairs: list[dict[str, Any]] = []
    schema_errors: list[str] = []
    skipped: Counter[str] = Counter(collection_skipped)
    defect_counts: Counter[str] = Counter()
    bucket_selection_counts: Counter[str] = Counter()

    # Reordering returns the same dict objects, so object identity is enough
    # to recover which run each prediction came from.
    source_by_object = {id(row): source for source, row in eligible_with_source}
    for prediction in bucket_weighted_order([row for _source, row in eligible_with_source]):
        if len(pairs) >= max_records:
            skipped["max_records_reached"] += 1
            continue
        pair = build_pair(
            prediction=prediction,
            asset_class=asset_class,
            role=role,
            run_id=run_id,
            source_run_id=source_by_object.get(id(prediction), run_id),
            record_index=len(pairs),
            include_reasoning_chosen=include_reasoning_chosen,
            repair_strategy=repair_strategy,
        )
        if pair is None:
            skipped["missing_prompt_or_answers"] += 1
            continue
        if pair["metadata"]["repair_target"]["admitted_to_training"] is not True:
            skipped["repair_target_not_admitted"] += 1
            continue
        errors = validate_pair(pair)
        if errors:
            schema_errors.extend(f"{pair.get('metadata', {}).get('source_prediction_id')}: {error}" for error in errors)
            continue
        pairs.append(pair)
        defect_counts.update(pair["metadata"]["defect_types"])
        bucket_selection_counts.update([pair["metadata"]["failure_bucket"]])

    output.parent.mkdir(parents=True, exist_ok=True)
    write_jsonl(output, pairs)
    output_sha = sha256_file(output)

    def count_where(predicate: Any) -> int:
        """Count kept pairs whose metadata satisfies *predicate*."""
        return sum(1 for kept in pairs if predicate(kept["metadata"]))

    summary = {
        "repair_strategy": repair_strategy,
        "preference_pair_count": len(pairs),
        "min_records": min_records,
        "min_records_met": len(pairs) >= min_records,
        "include_historical": include_historical,
        "historical_pair_count": count_where(lambda meta: meta["source_run_id"] != run_id),
        "pairwise_loss_pair_count": count_where(lambda meta: meta["pairwise_loss"]),
        "critical_failure_pair_count": count_where(lambda meta: meta["critical_failure"]),
        "gold_chosen_pair_count": count_where(lambda meta: meta.get("chosen_is_gold")),
        "baseline_capped_pair_count": count_where(lambda meta: meta.get("parity_capped")),
        "admitted_pair_count": count_where(lambda meta: meta["repair_target"]["admitted_to_training"]),
        "selection_strategy": f"{repair_strategy}_bucket_weighted_round_robin",
        "bucket_selection_counts": dict(bucket_selection_counts),
        "defect_counts": dict(defect_counts),
        "skipped": dict(skipped),
        "schema_error_count": len(schema_errors),
        "max_records": max_records,
        "include_critical_failures": include_critical_failures,
        "include_reasoning_chosen": include_reasoning_chosen,
    }
    training_recommendation = {
        "objective": "preference_optimization",
        "preferred_algorithms": ["DPO", "ORPO", "KTO"],
        "repair_strategy": repair_strategy,
        "reason": f"paired evaluation produced explicit chosen/rejected failures; optimize the {repair_strategy} lane before paid SFT breakout retries",
    }
    result = {
        "ok": len(pairs) >= min_records and not schema_errors,
        "schema_version": PREFERENCE_SCHEMA_VERSION,
        "run_id": run_id,
        "asset_class": asset_class,
        "role": role,
        "predictions_path": str(predictions),
        "paired_eval_report_path": str(report_path) if has_report else None,
        "paired_eval_improvement": report.get("improvement", {}),
        "output_path": str(output),
        "output_sha256": output_sha,
        "manifest_path": str(manifest_path),
        "markdown_path": str(markdown_path),
        "summary": summary,
        "schema_errors": schema_errors,
        "training_recommendation": training_recommendation,
        "created_at": utc_now(),
    }
    write_json(manifest_path, result)
    write_markdown_summary(markdown_path, result)
    return result
Xet Storage Details
- Size:
- 34.4 kB
- Xet hash:
- 034a1b8abb8173eef197db0ffe4a616c9c9dfa0ae17dcf52d3a9e48afd800cae
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.