linvest21's picture
download
raw
15.4 kB
from __future__ import annotations

import json
import math
import os
from pathlib import Path
from typing import Any

from eval.best_run_tracker import update_best_run
from eval.model_quality_gate import load_model_quality_thresholds
from n21.config import load_structured, write_json
from n21.settings import CONFIG_ROOT, SHFT_WORKSPACE_ROOT
from observability.audit_log import utc_now
def _load(path: Path) -> dict[str, Any] | None:
    """Read a JSON object from *path*, returning ``None`` when it is unusable.

    The file is decoded as ``utf-8-sig`` so a leading byte-order mark is
    tolerated. ``None`` is returned when the file is missing or unreadable,
    is not valid UTF-8, is not valid JSON, or holds a top-level value that is
    not a JSON object.
    """
    try:
        # Reading directly (no exists() pre-check) avoids a check-then-use
        # race; a missing file surfaces as FileNotFoundError, an OSError.
        payload = json.loads(path.read_text(encoding="utf-8-sig"))
    except (OSError, ValueError):
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        return None
    # Callers use the result as a mapping, so lists/scalars count as unusable.
    return payload if isinstance(payload, dict) else None
def _num(value: Any) -> float:
    """Coerce *value* to a finite float, falling back to ``0.0``.

    ``0.0`` is returned when the value cannot be converted (wrong type,
    unparsable text, an integer too large for a float) or when the conversion
    yields NaN or an infinity. Keeping the result finite matters because
    callers wrap it in ``int(...)``, which raises on NaN and infinities.
    """
    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return 0.0
    return number if math.isfinite(number) else 0.0
def _env_int(name: str, default: int) -> int:
    """Return environment variable *name* parsed as an int.

    When the variable is unset, *default* itself is passed through ``int``.
    If the lookup or the conversion raises ``TypeError`` or ``ValueError``,
    *default* is returned unchanged.
    """
    try:
        parsed = int(os.environ.get(name, default))
    except (TypeError, ValueError):
        return default
    else:
        return parsed
def _env_float(name: str, default: float) -> float:
    """Return environment variable *name* parsed as a float.

    When the variable is unset, *default* itself is passed through ``float``.
    If the lookup or the conversion raises ``TypeError`` or ``ValueError``,
    *default* is returned unchanged.
    """
    try:
        parsed = float(os.environ.get(name, default))
    except (TypeError, ValueError):
        return default
    else:
        return parsed
def _record_count(plan: dict[str, Any] | None) -> int:
    """Return ``plan["validation"]["record_count"]`` as an int, or ``0``.

    ``0`` is returned when *plan* is empty or not a mapping, when its
    ``validation`` entry is missing, null or not a mapping, or when the count
    cannot be turned into an integer.
    """
    if not isinstance(plan, dict):
        return 0
    # dict.get's default only applies to a *missing* key; an explicit JSON
    # null comes back as None, so the section type is checked explicitly.
    validation = plan.get("validation")
    if not isinstance(validation, dict):
        return 0
    try:
        return int(_num(validation.get("record_count")))
    except (ValueError, OverflowError):
        # int() rejects NaN and infinities.
        return 0
def _convergence_control(
    *,
    round_index: int,
    breakout: dict[str, Any],
    source_batch_acceptance: dict[str, Any],
    comparison: dict[str, Any],
) -> dict[str, Any]:
    """Decide whether the measured loop continues or paid retraining halts.

    Limits are read from the ``stall_breakout`` section of
    ``CONFIG_ROOT / "data" / "source_policy.yaml"``. The environment variables
    ``SHFT_MAX_BREAKOUT_ROUNDS``,
    ``SHFT_MIN_DISCOVERY_ATTEMPTS_BEFORE_REASONING_HALT``,
    ``SHFT_SEVERE_REGRESSION_AGGREGATE_ABS`` and
    ``SHFT_SEVERE_REGRESSION_CRITICAL_ABS`` take precedence when set to a
    parsable value. Missing or zero settings fall back to 3 rounds,
    2 discovery attempts and 0.05 for both regression thresholds.

    Args:
        round_index: Current round number; compared against ``max_rounds``.
        breakout: Stall-breakout plan. ``status``, ``intake`` and
            ``live_discovery`` are read.
        source_batch_acceptance: Acceptance record; only ``decision`` is read.
        comparison: Comparison of the current run with the previous best run.

    Returns:
        A ``shft_convergence_control_v1`` mapping. ``state`` is ``CONTINUE``
        with ``exit_code`` 0 unless a halt rule fires, in which case it is
        ``NEEDS_REASONING_DATA`` with ``exit_code`` 8 and ``reasons`` lists
        every rule that fired.
    """

    def _section(mapping: Any, key: str) -> dict[str, Any]:
        # A key that is present but null (or any non-mapping) must not reach
        # a chained .get(); treat it like a missing section.
        value = mapping.get(key) if isinstance(mapping, dict) else None
        return value if isinstance(value, dict) else {}

    try:
        policy = load_structured(CONFIG_ROOT / "data" / "source_policy.yaml")
    except Exception:
        # Deliberate best effort: an unreadable policy falls back to defaults.
        policy = {}
    stall_cfg = _section(policy, "stall_breakout")

    max_rounds = _env_int("SHFT_MAX_BREAKOUT_ROUNDS", int(_num(stall_cfg.get("max_rounds")) or 3))
    min_discovery_attempts_before_reasoning_halt = _env_int(
        "SHFT_MIN_DISCOVERY_ATTEMPTS_BEFORE_REASONING_HALT",
        int(_num(stall_cfg.get("min_discovery_attempts_before_reasoning_halt")) or 2),
    )
    # Thresholds are stored as magnitudes; the comparisons below negate them.
    severe_regression_aggregate_abs = abs(
        _env_float(
            "SHFT_SEVERE_REGRESSION_AGGREGATE_ABS",
            _num(stall_cfg.get("severe_regression_aggregate_abs")) or 0.05,
        )
    )
    severe_regression_critical_abs = abs(
        _env_float(
            "SHFT_SEVERE_REGRESSION_CRITICAL_ABS",
            _num(stall_cfg.get("severe_regression_critical_abs")) or 0.05,
        )
    )

    intake = _section(breakout, "intake")
    live_discovery = _section(breakout, "live_discovery")
    trainable_new_sources = int(_num(intake.get("trainable_new_source_count")))
    discovery_attempts = int(_num(live_discovery.get("attempt_count")))
    # None means "not reported", which is distinct from a reported zero.
    final_candidate_count_value = live_discovery.get("candidate_count")
    final_candidate_count = None if final_candidate_count_value is None else int(_num(final_candidate_count_value))
    final_intake_training_eligible_count = int(_num(live_discovery.get("intake_training_eligible_count")))
    final_content_ai_rejected_count = int(_num(live_discovery.get("content_ai_rejected_count")))
    final_ai_rejected_count = int(_num(live_discovery.get("ai_rejected_count")))

    blocked_after_breakout = breakout.get("status") == "blocked_after_breakout"
    rejected_batch = source_batch_acceptance.get("decision") == "rejected_did_not_improve_previous_best"
    has_previous_best = bool(comparison.get("previous_best_run_id"))
    improved = comparison.get("improved_vs_previous_best") is True
    aggregate_delta_vs_best = _num(comparison.get("aggregate_delta_vs_previous_best"))
    critical_delta_vs_best = _num(
        comparison.get("critical_delta_vs_previous_best", comparison.get("critical_pass_delta_vs_previous_best"))
    )
    rounds_exhausted = round_index >= max_rounds

    no_source_progress = blocked_after_breakout and trainable_new_sources == 0
    regressed_vs_best = rejected_batch and has_previous_best and not improved
    severe_regression = regressed_vs_best and (
        aggregate_delta_vs_best <= -severe_regression_aggregate_abs
        or critical_delta_vs_best <= -severe_regression_critical_abs
    )
    discovery_exhausted = discovery_attempts >= min_discovery_attempts_before_reasoning_halt
    # Latest discovery retry returned no candidates at all.
    no_candidate_retry_exhausted = (
        no_source_progress
        and discovery_attempts > 0
        and final_candidate_count == 0
    )
    # Latest discovery retry returned candidates, none training-eligible, and
    # at least one was rejected by an AI check.
    no_trainable_candidate_retry_exhausted = (
        no_source_progress
        and discovery_attempts > 0
        and final_candidate_count is not None
        and final_candidate_count > 0
        and final_intake_training_eligible_count == 0
        and (final_content_ai_rejected_count > 0 or final_ai_rejected_count > 0)
    )

    reasons: list[str] = []
    state = "CONTINUE"
    action = "continue_measured_loop"
    should_halt = False
    exit_code = 0

    if no_source_progress and severe_regression:
        state = "NEEDS_REASONING_DATA"
        action = "halt_paid_retraining_and_generate_or_supply_reasoning_examples"
        should_halt = True
        exit_code = 8
        reasons.append(
            "blocked_after_breakout produced zero trainable sources and the current run materially "
            "regressed against the protected best checkpoint"
        )
    elif no_source_progress and regressed_vs_best and (
        discovery_exhausted or no_candidate_retry_exhausted or no_trainable_candidate_retry_exhausted
    ):
        state = "NEEDS_REASONING_DATA"
        action = "halt_paid_retraining_and_generate_or_supply_reasoning_examples"
        should_halt = True
        exit_code = 8
        if no_candidate_retry_exhausted and not discovery_exhausted:
            reasons.append(
                "blocked_after_breakout produced zero trainable sources, "
                "the latest live-discovery retry returned zero candidates, "
                "and the current run did not improve the protected best checkpoint"
            )
        elif no_trainable_candidate_retry_exhausted and not discovery_exhausted:
            reasons.append(
                "blocked_after_breakout produced zero trainable sources, "
                "the latest live-discovery retry found candidates but none were training-eligible, "
                "and the current run did not improve the protected best checkpoint"
            )
        else:
            reasons.append(
                "blocked_after_breakout produced zero trainable sources, "
                f"discovery attempted at least {min_discovery_attempts_before_reasoning_halt} time(s), "
                "and the current run did not improve the protected best checkpoint"
            )
    elif no_source_progress and rounds_exhausted:
        state = "NEEDS_REASONING_DATA"
        action = "halt_paid_retraining_and_generate_or_supply_reasoning_examples"
        should_halt = True
        exit_code = 8
        reasons.append(
            f"blocked_after_breakout produced zero trainable sources after round {round_index} "
            f"(max_rounds={max_rounds})"
        )

    # Evaluated independently of the chain above: when it fires it replaces
    # the action chosen there and adds a further reason.
    if regressed_vs_best and rounds_exhausted:
        state = "NEEDS_REASONING_DATA"
        action = "halt_paid_retraining_and_preserve_best_checkpoint"
        should_halt = True
        exit_code = 8
        reasons.append("current source batch did not improve the frozen best measured checkpoint")

    return {
        "schema_version": "shft_convergence_control_v1",
        "state": state,
        "action": action,
        "should_halt_paid_retraining": should_halt,
        "exit_code": exit_code,
        "max_rounds": max_rounds,
        "min_discovery_attempts_before_reasoning_halt": min_discovery_attempts_before_reasoning_halt,
        "severe_regression_thresholds": {
            "aggregate_abs": severe_regression_aggregate_abs,
            "critical_abs": severe_regression_critical_abs,
        },
        "aggregate_delta_vs_previous_best": aggregate_delta_vs_best,
        "critical_delta_vs_previous_best": critical_delta_vs_best,
        "severe_regression": severe_regression,
        "round_index": round_index,
        "live_discovery_attempt_count": discovery_attempts,
        "final_live_discovery_candidate_count": final_candidate_count,
        "final_live_discovery_training_eligible_count": final_intake_training_eligible_count,
        "final_live_discovery_content_ai_rejected_count": final_content_ai_rejected_count,
        "final_live_discovery_ai_rejected_count": final_ai_rejected_count,
        "no_candidate_retry_exhausted": no_candidate_retry_exhausted,
        "no_trainable_candidate_retry_exhausted": no_trainable_candidate_retry_exhausted,
        "best_checkpoint_protected": regressed_vs_best,
        "reasoning_data_required": should_halt,
        "reasons": reasons,
    }
def write_continuous_status(
    *,
    run_id: str,
    release_id: str,
    asset_class: str,
    role: str,
    round_index: int,
    phase: str,
) -> dict[str, Any]:
    """Write operator-facing continuous-training status and next-data strategy.

    This is deliberately transparent planning: ``certified`` only mirrors the
    ``ok`` flag of the run's quality-gate report, and missing judge/human
    evidence is reported as an action item rather than filled in.

    Inputs are read from ``SHFT_WORKSPACE_ROOT / "runs" / run_id``: the
    quality gate and paired-eval reports under ``eval/``, the training result
    and plan under ``remote_artifacts/``, the dataset manifest under
    ``dataset_snapshot/`` and the plan under ``stall_breakout/``. Any report
    that cannot be loaded is treated as empty.

    The status is written to
    ``SHFT_WORKSPACE_ROOT / "continuous_training" / f"{release_id}_status.json"``
    and to ``continuous_training_status.json`` in the run directory; its
    ``next_data_strategy`` section is also written to
    ``next_data_strategy.json`` in the run directory.

    Args:
        run_id: Run whose artifacts are summarised.
        release_id: Release used for best-run tracking and the output name.
        asset_class: Copied into the status and the escalation message.
        role: Copied into the status and the escalation message.
        round_index: Current round, forwarded to the convergence decision.
        phase: Copied into the status.

    Returns:
        The status mapping that was written.
    """

    def _section(mapping: Any, key: str) -> dict[str, Any]:
        # A key that is present but null (or any non-mapping) must not reach
        # a chained .get(); treat it like a missing section.
        value = mapping.get(key) if isinstance(mapping, dict) else None
        return value if isinstance(value, dict) else {}

    def _metric(primary: Any, fallback: Any) -> float:
        # A measured 0 (e.g. a 0.0 loss rate) is a real value and must not be
        # replaced by the fallback just because it is falsy.
        if isinstance(primary, (int, float)) and not isinstance(primary, bool):
            return _num(primary)
        return _num(primary or fallback)

    run_dir = SHFT_WORKSPACE_ROOT / "runs" / run_id
    best = update_best_run(run_id=run_id, release_id=release_id)
    gate = _load(run_dir / "eval" / "model_quality_gate.json") or {}
    paired = _load(run_dir / "eval" / "paired_eval_report.json") or {}
    training = _load(run_dir / "remote_artifacts" / "training_result.json") or {}
    training_plan = _load(run_dir / "remote_artifacts" / "training_plan.json") or {}
    dataset_manifest = _load(run_dir / "dataset_snapshot" / "dataset_manifest.json") or {}
    breakout = _load(run_dir / "stall_breakout" / "stall_breakout_plan.json") or {}
    thresholds = load_model_quality_thresholds()

    best_run = best.get("best_run") or {}
    current = best.get("current_run") or {}
    source_batch_acceptance = best.get("source_batch_acceptance") or {}
    previous_best_comparison = best.get("previous_best_comparison") or {}

    candidate = _section(paired, "candidate")
    improvement = _section(paired, "improvement")
    split_counts = _section(dataset_manifest, "split_counts")
    intake = _section(breakout, "intake")
    live_discovery = _section(breakout, "live_discovery")

    train_records = int(_num(training_plan.get("train_records") or split_counts.get("train")))
    valid_records = int(_num(training_plan.get("valid_records") or split_counts.get("valid")))
    train_loss = training.get("train_loss")

    gate_errors = list(gate.get("errors") or current.get("gate_errors") or [])
    # str() keeps the join safe when an error entry is not already a string.
    error_text = "\n".join(str(item) for item in gate_errors).lower()

    # Each marker found in the gate errors contributes one action, in order.
    advice_by_marker = (
        (
            "critical_pass",
            "Add harder role-specific examples with explicit critical pass/fail reasoning and red-flag decisions.",
        ),
        (
            "candidate_aggregate_absolute",
            "Add broader high-quality source material and regenerate training JSONL before the next paid retrain.",
        ),
        (
            "model_as_judge",
            "Produce eval/model_judge_report.json from the configured rubric before certification.",
        ),
        (
            "human_spot_check",
            "Complete human spot-check evidence before promotion.",
        ),
    )
    strategy_reasons: list[str] = [advice for marker, advice in advice_by_marker if marker in error_text]
    if breakout.get("status") == "blocked_after_breakout":
        strategy_reasons.append("Continue public-source discovery with fresh retry terms; skip failed, duplicate, and non-trainable sources.")
    if not strategy_reasons:
        strategy_reasons.append("Continue measured train/eval cycles until all configured gates pass.")

    convergence = _convergence_control(
        round_index=round_index,
        breakout=breakout,
        source_batch_acceptance=source_batch_acceptance,
        comparison=previous_best_comparison,
    )
    if convergence["reasoning_data_required"]:
        strategy_reasons.append(
            "Generate release-wide paired-eval failure repair and grounded critical-reasoning data before another paid retrain."
        )

    status = {
        "schema_version": "shft_continuous_training_status_v1",
        "release_id": release_id,
        "run_id": run_id,
        "asset_class": asset_class,
        "role": role,
        "round_index": round_index,
        "phase": phase,
        "certified": bool(gate.get("ok")),
        "current_intelligence": {
            "candidate_aggregate": _metric(candidate.get("aggregate"), current.get("candidate_aggregate")),
            "candidate_critical_pass_rate": _metric(
                candidate.get("critical_pass_rate"), current.get("candidate_critical_pass_rate")
            ),
            "aggregate_abs": _metric(improvement.get("aggregate_abs"), current.get("aggregate_abs")),
            "pairwise_win_rate": _metric(improvement.get("pairwise_win_rate"), current.get("pairwise_win_rate")),
            "pairwise_loss_rate": _metric(improvement.get("pairwise_loss_rate"), current.get("pairwise_loss_rate")),
            "train_loss": train_loss,
            "train_records": train_records,
            "valid_records": valid_records,
        },
        "best_intelligence": best_run,
        "previous_best_comparison": previous_best_comparison,
        "source_batch_acceptance": source_batch_acceptance,
        "convergence_control": convergence,
        "distance_to_thresholds": current.get("distance_to_thresholds") or best_run.get("distance_to_thresholds") or {},
        "quality_gate_errors": gate_errors,
        "thresholds": thresholds,
        "breakout": {
            "status": breakout.get("status"),
            "record_count": _record_count(breakout),
            "trainable_new_source_count": intake.get("trainable_new_source_count", 0),
            "live_discovery_attempt_count": live_discovery.get("attempt_count", 0),
            "blockers": breakout.get("blockers", []),
        },
        "next_data_strategy": {
            "objective": "increase measured model quality until the mandatory quality gate passes",
            "actions": strategy_reasons,
            "source_policy": "download public material automatically; train only policy-approved, AI-certified, normalized training-eligible sources; keep a source batch only if post-training evidence improves the previous best checkpoint",
            "escalation": (
                f"{asset_class}/{role} needs high-signal critical-reasoning examples"
                if convergence["reasoning_data_required"]
                else None
            ),
        },
        "artifacts": {
            "run_dir": str(run_dir),
            "best_run_report": str(SHFT_WORKSPACE_ROOT / "best_runs" / f"{release_id}.json"),
            "paired_eval_report": str(run_dir / "eval" / "paired_eval_report.json"),
            "quality_gate_report": str(run_dir / "eval" / "model_quality_gate.json"),
            "stall_breakout_plan": str(run_dir / "stall_breakout" / "stall_breakout_plan.json"),
        },
        "created_at": utc_now(),
        "ok": True,
    }

    out_dir = SHFT_WORKSPACE_ROOT / "continuous_training"
    out_dir.mkdir(parents=True, exist_ok=True)
    write_json(out_dir / f"{release_id}_status.json", status)
    write_json(run_dir / "continuous_training_status.json", status)
    write_json(run_dir / "next_data_strategy.json", status["next_data_strategy"])
    return status

Xet Storage Details

Size:
15.4 kB
·
Xet hash:
c688e310ff36b3c3a73642c7c20d50496d668134fd6fc360c49579e779b3af93

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.