linvest21/shft-artifacts / code /self_healing_finetuning /tests /test_pairwise_preference_memory.py
linvest21's picture
download
raw
15 kB
from __future__ import annotations
import json
import tempfile
import unittest
from pathlib import Path
from unittest import mock
from data_pipeline import pairwise_preference_memory as preference_memory
from data_pipeline.pairwise_preference_memory import build_pairwise_preference_data
def write_jsonl(path: Path, rows: list[dict[str, object]]) -> None:
    """Write ``rows`` to ``path`` as JSON Lines, creating parent directories.

    Each row is serialized with ``json.dumps`` and terminated by a newline, so a
    non-empty file always ends with a newline. An empty ``rows`` list produces an
    empty file instead of a single blank line, which a reader that calls
    ``json.loads`` on every line would reject.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text("".join(f"{json.dumps(row)}\n" for row in rows), encoding="utf-8")
class PairwisePreferenceMemoryTests(unittest.TestCase):
    """Tests for ``data_pipeline.pairwise_preference_memory``.

    Builder tests run against a throwaway workspace: ``SHFT_WORKSPACE_ROOT`` is
    patched to a temporary directory, paired predictions are written under
    ``runs/<run_id>/eval/``, and ``build_pairwise_preference_data`` is invoked.
    The remaining tests call the module's repair helpers directly.
    """

    @staticmethod
    def _predictions_path(workspace: Path, run_id: str) -> Path:
        """Return the ``paired_predictions.jsonl`` path these tests populate for ``run_id``."""
        return workspace / "runs" / run_id / "eval" / "paired_predictions.jsonl"

    @staticmethod
    def _build(workspace: Path, run_id: str, **kwargs):
        """Call ``build_pairwise_preference_data`` with the workspace root patched to ``workspace``.

        The patch only covers the call itself; output paths in the returned result
        are read afterwards, while the temporary directory still exists.
        """
        with mock.patch.object(preference_memory, "SHFT_WORKSPACE_ROOT", workspace):
            return build_pairwise_preference_data(run_id=run_id, **kwargs)

    @staticmethod
    def _read_pairs(path: str | Path) -> list[dict]:
        """Parse the JSONL file at ``path`` into a list of preference-pair dicts."""
        text = Path(path).read_text(encoding="utf-8")
        return [json.loads(line) for line in text.splitlines()]

    def test_builds_preference_pairs_from_pairwise_losses_and_critical_failures(self) -> None:
        """A pairwise loss and a critical failure each become a pair; a pairwise win does not."""
        with tempfile.TemporaryDirectory() as tmp:
            workspace = Path(tmp)
            run_id = "run_test_pairwise_memory"
            predictions = self._predictions_path(workspace, run_id)
            eval_dir = predictions.parent
            write_jsonl(
                predictions,
                [
                    {
                        "id": "loss_1",
                        "task": "quantitative_qa",
                        "prompt": "Revenue grew from 100 to 110 while margin fell from 20% to 15%. What matters?",
                        "baseline_answer": "<think>calculate ratios</think> Revenue rose from 100 to 110, but margin fell from 20% to 15%, an offsetting risk that pressures profit even as the top line grows.",
                        "candidate_answer": "<think>ignore offsetting risk</think> This is clearly bullish because revenue rose.",
                        "baseline_score": {"score": 0.9, "critical_pass": True},
                        "candidate_score": {"score": 0.4, "critical_pass": True},
                        "delta": -0.5,
                    },
                    {
                        "id": "critical_1",
                        "task": "finance_qa",
                        "prompt": "A company reports one quarter of better free cash flow. Should rating change?",
                        "candidate_answer": "Upgrade immediately based on one quarter.",
                        "candidate_score": {"score": 0.6, "critical_pass": False},
                        "delta": 0.0,
                    },
                    {
                        "id": "win_1",
                        "task": "finance_qa",
                        "prompt": "What is a balanced answer?",
                        "baseline_answer": "Weak answer.",
                        "candidate_answer": "Strong answer.",
                        "baseline_score": {"score": 0.5, "critical_pass": True},
                        "candidate_score": {"score": 0.8, "critical_pass": True},
                        "delta": 0.3,
                    },
                ],
            )
            (eval_dir / "paired_eval_report.json").write_text(
                json.dumps({"improvement": {"losses": 1, "pairwise_loss_rate": 0.3333}}),
                encoding="utf-8",
            )
            result = self._build(
                workspace,
                run_id,
                asset_class="equity",
                role="researcher",
                max_records=10,
            )

            self.assertTrue(result["ok"], result)
            summary = result["summary"]
            self.assertEqual(summary["preference_pair_count"], 2)
            self.assertEqual(summary["pairwise_loss_pair_count"], 1)
            self.assertEqual(summary["critical_failure_pair_count"], 1)

            pairs = self._read_pairs(result["output_path"])
            # The JSONL output should hold exactly the pairs the summary reports.
            self.assertEqual(len(pairs), summary["preference_pair_count"])
            loss_pair, critical_pair = pairs

            self.assertEqual(loss_pair["prompt"], "Revenue grew from 100 to 110 while margin fell from 20% to 15%. What matters?")
            self.assertEqual(
                loss_pair["chosen"],
                "Revenue rose from 100 to 110, but margin fell from 20% to 15%, an offsetting risk that pressures profit even as the top line grows.",
            )
            self.assertEqual(loss_pair["rejected"], "This is clearly bullish because revenue rose.")
            # Reasoning traces must be stripped from both sides of the pair.
            self.assertNotIn("<think>", loss_pair["chosen"])
            self.assertNotIn("<think>", loss_pair["rejected"])

            loss_meta = loss_pair["metadata"]
            self.assertTrue(loss_meta["pairwise_loss"])
            # Baseline genuinely beat the candidate here, so chosen=baseline is a
            # legitimate "match the better answer" target, not a parity-cap fallback.
            self.assertEqual(loss_meta["chosen_source"], "baseline_winning_answer")
            self.assertIn("failure_bucket", loss_meta)
            self.assertIn(loss_meta["failure_bucket"], preference_memory.DIAGNOSTIC_BUCKETS)
            self.assertTrue(loss_meta["judge_rationale"])
            self.assertTrue(loss_meta["repair_target"]["admitted_to_training"])
            self.assertNotIn("<think>", loss_meta["repair_target"]["answer"])

            self.assertTrue(critical_pair["metadata"]["critical_failure"])
            self.assertNotEqual(critical_pair["chosen"], critical_pair["rejected"])

            self.assertTrue(Path(result["manifest_path"]).exists())
            self.assertTrue(Path(result["markdown_path"]).exists())

    def test_losses_only_skips_non_loss_critical_failures(self) -> None:
        """With critical failures excluded, a critical-only row yields no pairs and is counted as skipped."""
        with tempfile.TemporaryDirectory() as tmp:
            workspace = Path(tmp)
            run_id = "run_test_losses_only"
            write_jsonl(
                self._predictions_path(workspace, run_id),
                [
                    {
                        "id": "critical_only",
                        "prompt": "One good quarter with weak balance sheet.",
                        "candidate_answer": "No risk exists.",
                        "candidate_score": {"score": 0.5, "critical_pass": False},
                        "delta": 0.0,
                    }
                ],
            )
            result = self._build(
                workspace,
                run_id,
                asset_class="equity",
                role="researcher",
                include_critical_failures=False,
            )

            self.assertFalse(result["ok"])
            self.assertEqual(result["summary"]["preference_pair_count"], 0)
            self.assertEqual(result["summary"]["skipped"]["not_loss_or_critical_failure"], 1)

    def test_explicit_gold_answer_breaks_baseline_parity_trap(self) -> None:
        """A supplied ``gold_answer`` is used as ``chosen`` in preference to the baseline answer."""
        prediction = {
            "id": "gold_1",
            "task": "finance_qa",
            "prompt": "Summarize the revenue risk.",
            "baseline_answer": "Revenue grew, which is good.",
            "candidate_answer": "Revenue grew, so buy.",
            "gold_answer": "Reported revenue grew, but the inference of durable growth is unsupported; the key risk is that backlog declined, which often precedes a revenue slowdown.",
            "baseline_score": {"score": 0.7, "critical_pass": True},
            "candidate_score": {"score": 0.4, "critical_pass": True},
            "delta": -0.3,
        }
        chosen, source = preference_memory.corrected_chosen_answer(
            prediction=prediction, asset_class="equity", role="researcher", defects=["pairwise_loss"]
        )
        self.assertEqual(source, "gold_answer")
        self.assertIn("backlog declined", chosen)
        self.assertNotEqual(chosen, prediction["baseline_answer"])

    def test_tie_without_gold_synthesizes_instead_of_capping_at_baseline(self) -> None:
        """Without a gold answer, a critical failure gets a synthetic ``chosen`` rather than the baseline."""
        prediction = {
            "id": "tie_1",
            "task": "finance_qa",
            "prompt": "One good quarter of free cash flow; should the rating change?",
            "baseline_answer": "Maybe.",
            "candidate_answer": "Yes, upgrade now.",
            "candidate_score": {"score": 0.5, "critical_pass": False},
            "delta": 0.0,
        }
        chosen, source = preference_memory.corrected_chosen_answer(
            prediction=prediction, asset_class="equity", role="researcher", defects=["critical_failure"]
        )
        self.assertEqual(source, "rubric_grounded_synthetic")
        self.assertNotEqual(chosen, prediction["baseline_answer"])

    def test_quant_chosen_without_numbers_is_not_admitted_to_training(self) -> None:
        """A number-free ``chosen`` for a numeric quantitative prompt fails the acceptance checks."""
        checks = preference_memory.repair_acceptance_checks(
            prompt="Revenue grew from 100 to 110 while margin fell from 20% to 15%. What matters?",
            chosen="The favorable item should be weighed against the adverse item before changing a view here.",
            rejected="Clearly bullish.",
            task="quantitative_qa",
        )
        self.assertFalse(checks["numeric_answer_for_numeric_prompt"])
        self.assertFalse(preference_memory.repair_admitted_to_training(checks))

    def test_bucket_weighted_selection_keeps_coverage_under_tight_cap(self) -> None:
        """With ``max_records`` smaller than the loss count, a minority failure bucket still gets selected."""
        with tempfile.TemporaryDirectory() as tmp:
            workspace = Path(tmp)
            run_id = "run_test_bucket_weighted"
            # 6 losses in one bucket (valuation_math) followed by 1 in another
            # (moat_reasoning). First-come truncation at 2 would drop the moat bucket.
            rows: list[dict[str, object]] = [
                {
                    "id": f"val_{i}",
                    "task": "finance_qa",
                    "prompt": "Discuss the valuation multiple and growth ratio implications.",
                    "baseline_answer": "Baseline weighs the multiple against growth and flags the key risk.",
                    "candidate_answer": "It is clearly cheap.",
                    "baseline_score": {"score": 0.9, "critical_pass": True},
                    "candidate_score": {"score": 0.3, "critical_pass": True},
                    "delta": -0.6,
                }
                for i in range(6)
            ]
            rows.append(
                {
                    "id": "moat_1",
                    "task": "finance_qa",
                    "prompt": "Assess the competitive moat and pricing power durability.",
                    "baseline_answer": "Baseline weighs switching costs against competitive risk and stays neutral.",
                    "candidate_answer": "The moat is unbreakable.",
                    "baseline_score": {"score": 0.9, "critical_pass": True},
                    "candidate_score": {"score": 0.3, "critical_pass": True},
                    "delta": -0.6,
                }
            )
            write_jsonl(self._predictions_path(workspace, run_id), rows)
            result = self._build(
                workspace,
                run_id,
                asset_class="equity",
                role="researcher",
                max_records=2,
            )

            # Like the other builder tests, fail fast with the full result if the build failed.
            self.assertTrue(result["ok"], result)
            summary = result["summary"]
            self.assertEqual(summary["preference_pair_count"], 2)
            self.assertIn("moat_reasoning", summary["bucket_selection_counts"])
            self.assertEqual(summary["selection_strategy"], "generic_loss_targeted_bucket_weighted_round_robin")

    def test_critical_safety_strategy_excludes_noncritical_pairwise_losses(self) -> None:
        """``critical_safety_repair`` keeps the critical regression and filters out the ordinary loss."""
        with tempfile.TemporaryDirectory() as tmp:
            workspace = Path(tmp)
            run_id = "run_test_critical_strategy"
            write_jsonl(
                self._predictions_path(workspace, run_id),
                [
                    {
                        "id": "ordinary_loss",
                        "task": "finance_qa",
                        "prompt": "Assess the valuation multiple and growth risk.",
                        "baseline_answer": "The baseline separates valuation risk from growth facts and stays neutral.",
                        "candidate_answer": "It is obviously cheap.",
                        "baseline_score": {"score": 0.9, "critical_pass": True},
                        "candidate_score": {"score": 0.3, "critical_pass": True},
                        "delta": -0.6,
                    },
                    {
                        "id": "critical_regression",
                        "task": "finance_qa",
                        "prompt": "One good quarter of free cash flow; should the rating change?",
                        "baseline_answer": "No rating change is justified from one quarter without confirming recurring cash conversion and balance-sheet risk.",
                        "candidate_answer": "Upgrade immediately.",
                        "baseline_score": {"score": 0.8, "critical_pass": True},
                        "candidate_score": {"score": 0.6, "critical_pass": False},
                        "delta": 0.0,
                    },
                ],
            )
            result = self._build(
                workspace,
                run_id,
                asset_class="equity",
                role="researcher",
                max_records=10,
                repair_strategy="critical_safety_repair",
            )

            self.assertTrue(result["ok"], result)
            summary = result["summary"]
            self.assertEqual(summary["repair_strategy"], "critical_safety_repair")
            self.assertEqual(summary["preference_pair_count"], 1)
            self.assertEqual(summary["critical_failure_pair_count"], 1)
            self.assertEqual(summary["pairwise_loss_pair_count"], 0)
            self.assertEqual(summary["skipped"]["strategy_filtered_critical_safety_repair"], 1)

            pairs = self._read_pairs(result["output_path"])
            self.assertEqual(len(pairs), summary["preference_pair_count"])
            pair = pairs[0]
            self.assertEqual(pair["metadata"]["source_prediction_id"], "critical_regression")
            self.assertEqual(pair["metadata"]["repair_strategy"], "critical_safety_repair")

    def test_preamble_chosen_is_not_admitted_to_training(self) -> None:
        """A ``chosen`` answer opening with a conversational preamble fails the acceptance checks."""
        checks = preference_memory.repair_acceptance_checks(
            prompt="Summarize the revenue risk.",
            chosen="Sure, here is a summary of the revenue risk you asked about for the company.",
            rejected="Bullish.",
            task="finance_qa",
        )
        self.assertFalse(checks["no_preamble"])
        self.assertFalse(preference_memory.repair_admitted_to_training(checks))
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()

Xet Storage Details

Size:
15 kB
·
Xet hash:
26685d899b730c1a33f7377a6ae1b9c4b58c16d5279550d61080d031d2c2a256

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.