# idacy's picture
# Handle no-active-allocation counterfactuals
# ce60faf verified
# Raw
# History Blame Contribute Delete
# 7.13 kB
"""Smoke checks for the live sklearn inference service.
Run from the repository root:
python -m src.datacenter_verification_api.smoke_test
"""
from __future__ import annotations
import argparse
from pathlib import Path
import pandas as pd
from src.datacenter_verification_api.model_service import (
DEFAULT_MODEL_RUN_DIR,
KNOWN_NON_MODEL_METADATA_FIELDS,
REPO_ROOT,
ModelService,
resolve_repo_path,
)
from src.datacenter_verification_api.schemas import PredictRequest
from src.datacenter_verification_modeling.common import LABELS
def parse_args() -> argparse.Namespace:
    """Parse the smoke test's command-line options.

    Returns:
        A namespace with ``model_run`` (a ``Path``, defaulting to
        ``DEFAULT_MODEL_RUN_DIR``) and ``tolerance`` (a ``float``,
        defaulting to ``1e-8``).
    """
    # The module docstring doubles as the ``--help`` description.
    cli = argparse.ArgumentParser(description=__doc__)
    option_specs = (
        ("--model-run", {"type": Path, "default": DEFAULT_MODEL_RUN_DIR}),
        ("--tolerance", {"type": float, "default": 1e-8}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
def selected_prediction_rows(
    predictions: pd.DataFrame,
    labels: list[int] | tuple[int, ...] | None = None,
) -> pd.DataFrame:
    """Pick a small, label-diverse sample of exported prediction rows.

    For every requested label the first row whose ``predicted_label`` equals
    that label is selected. When some labels never occur, the shortfall is
    filled with the earliest rows whose ``feature_row_id`` has not been
    selected yet, so the sample still reaches one row per label whenever the
    table holds enough distinct rows.

    Args:
        predictions: Exported predictions; must provide the
            ``predicted_label`` and ``feature_row_id`` columns.
        labels: Labels to cover. Defaults to the module-level ``LABELS``.

    Returns:
        The selected rows (per-label picks first, in label order, then the
        fill rows), de-duplicated on ``feature_row_id``. An empty
        ``predictions`` table yields an empty frame with the same columns.
    """
    wanted = list(LABELS if labels is None else labels)
    if predictions.empty:
        # Nothing to sample; keep the columns so callers can still index them.
        return predictions.head(0).copy()

    # Convert once instead of once per label.
    predicted = predictions["predicted_label"].astype(int)
    chosen: list[int] = []
    for label in wanted:
        matches = (predicted == label).to_numpy()
        if matches.any():
            # argmax on a boolean array is the position of the first True.
            chosen.append(int(matches.argmax()))

    shortfall = len(wanted) - len(chosen)
    if shortfall > 0:
        # Fill only with rows that are not already in the sample; blindly
        # taking the head would mostly re-pick rows that get dropped again
        # by the de-duplication below.
        row_ids = predictions["feature_row_id"].tolist()
        taken = {row_ids[position] for position in chosen}
        for position, row_id in enumerate(row_ids):
            if shortfall == 0:
                break
            if row_id in taken:
                continue
            taken.add(row_id)
            chosen.append(position)
            shortfall -= 1

    return predictions.iloc[chosen].drop_duplicates("feature_row_id")
def main() -> int:
    """Run the live-inference smoke checks and return a process exit code.

    The checks run in order against a ``ModelService`` built from the
    selected model run:

    1. Exported rows: for a sample of ``predictions_all.csv`` the live
       prediction must reproduce the exported label and every ``p_label_*``
       probability within ``--tolerance``.
    2. Edited prediction: a request overriding a few features must return
       one probability per label, summing to 1, plus completed features.
    3. UI payload regression: a request carrying every model feature and
       the known non-model metadata fields must produce no input warnings
       and must keep the edited ``o4_sm_tensor_active_p95`` value.
    4. Zero-duration counterfactual: setting
       ``o2_allocation_duration_hours`` to 0 must predict label 0 with
       probability above 0.9 and clear the listed activity features.

    Returns:
        ``0`` when every check passes.

    Raises:
        AssertionError: If any check fails, or if there are no exported
            prediction rows to compare against.
    """
    args = parse_args()
    model_run_dir = resolve_repo_path(args.model_run, DEFAULT_MODEL_RUN_DIR)
    service = ModelService(
        model_run_dir=model_run_dir,
        feature_table_path=REPO_ROOT / "data/synthetic_v1/features/window_features_all.csv",
    )
    predictions_path = model_run_dir / "predictions_all.csv"
    predictions = pd.read_csv(predictions_path)
    # Fail with a clear message instead of an opaque KeyError/IndexError
    # further down when the export holds no rows.
    if predictions.empty:
        raise AssertionError(f"{predictions_path} contains no exported prediction rows")
    sample = selected_prediction_rows(predictions)
    if sample.empty:
        raise AssertionError(f"no prediction rows could be sampled from {predictions_path}")

    # --- Check 1: live predictions reproduce the exported ones. ---
    checked = 0
    for _, expected in sample.iterrows():
        row_id = str(expected["feature_row_id"])
        actual = service.predict(PredictRequest(feature_row_id=row_id))
        expected_label = int(expected["predicted_label"])
        if actual.predicted_label != expected_label:
            raise AssertionError(
                f"{row_id}: predicted label mismatch {actual.predicted_label} != {expected_label}"
            )
        for label in LABELS:
            expected_probability = float(expected[f"p_label_{label}"])
            actual_probability = actual.probability_by_label[str(label)]
            delta = abs(actual_probability - expected_probability)
            if delta > args.tolerance:
                raise AssertionError(
                    f"{row_id}: p_label_{label} mismatch {actual_probability} != {expected_probability} "
                    f"(delta={delta})"
                )
        checked += 1

    # --- Check 2: an edited request still returns a well-formed response. ---
    edit_row_id = str(sample.iloc[-1]["feature_row_id"])
    edited = service.predict(
        PredictRequest(
            feature_row_id=edit_row_id,
            features={
                "o2_max_concurrent_normalized_gpus": 1024,
                "o2_allocation_duration_hours": 48,
                "o4_gpu_util_p95": 82,
            },
            context={"scope_type": "topology_domain", "window_length_seconds": 3600},
            derive=True,
            return_completed_features=True,
        )
    )
    if len(edited.probabilities) != len(LABELS):
        raise AssertionError("edited prediction did not return five label probabilities")
    if abs(sum(edited.probabilities) - 1.0) > 1e-8:
        raise AssertionError("edited prediction probabilities do not sum to 1")
    if not edited.completed_features:
        raise AssertionError("edited prediction did not return completed_features")

    # --- Check 3: a full UI-style payload keeps the edited tensor value. ---
    target_row_id = "feat_455d59646b2f3bc099ffd959"
    if target_row_id not in service.feature_lookup:
        # Fall back to a row known to exist when the pinned row is absent.
        target_row_id = edit_row_id
    base_row = service.feature_lookup[target_row_id]
    full_ui_features = {column: base_row.get(column) for column in service.feature_columns}
    for column in KNOWN_NON_MODEL_METADATA_FIELDS:
        full_ui_features[column] = base_row.get(column, False)
    full_ui_features["o4_sm_tensor_active_p95"] = 10
    # Both remaining requests share the same context; each gets its own copy.
    ui_context = {
        "scope_type": base_row.get("scope_type") or "topology_domain",
        "window_length_seconds": base_row.get("window_length_seconds") or 3600,
    }
    tensor_edit = service.predict(
        PredictRequest(
            feature_row_id=target_row_id,
            features=full_ui_features,
            context=dict(ui_context),
            derive=True,
            return_completed_features=True,
        )
    )
    if tensor_edit.input_warnings:
        raise AssertionError(f"UI-style metadata payload produced user-facing warnings: {tensor_edit.input_warnings}")
    if not tensor_edit.completed_features:
        raise AssertionError("tensor activity edit did not return completed_features")
    completed_tensor = tensor_edit.completed_features.get("o4_sm_tensor_active_p95")
    if completed_tensor != 10:
        raise AssertionError(f"tensor activity edit was not preserved; completed value is {completed_tensor!r}")

    # --- Check 4: zero allocation duration collapses to label 0. ---
    zero_duration_features = dict(full_ui_features)
    zero_duration_features["o2_allocation_duration_hours"] = 0
    zero_duration = service.predict(
        PredictRequest(
            feature_row_id=target_row_id,
            features=zero_duration_features,
            context=dict(ui_context),
            derive=True,
            return_completed_features=True,
        )
    )
    if zero_duration.predicted_label != 0:
        raise AssertionError(f"zero-duration edit did not return L0; got L{zero_duration.predicted_label}")
    if zero_duration.probability_by_label["0"] <= 0.9:
        raise AssertionError(f"zero-duration edit did not strongly prefer L0: {zero_duration.probability_by_label}")
    completed_zero = zero_duration.completed_features
    if not completed_zero:
        raise AssertionError("zero-duration edit did not return completed_features")
    expected_zero_fields = [
        "o2_max_concurrent_normalized_gpus",
        "o2_allocation_duration_hours",
        "o4_gpu_util_p95",
        "o4_sm_tensor_active_p95",
        "o7_synchronized_fabric_footprint",
        "o11_checkpoint_periodicity_score",
    ]
    nonzero_fields = [field for field in expected_zero_fields if completed_zero.get(field) != 0]
    if nonzero_fields:
        raise AssertionError(f"zero-duration edit left active evidence in completed features: {nonzero_fields}")
    if completed_zero.get("o12_signed_ml_logs_present") is not False:
        raise AssertionError("zero-duration edit did not clear signed ML logs")

    print(f"PASS live inference smoke test: matched {checked} exported rows from {predictions_path}")
    print(
        "PASS edited prediction schema: "
        f"row={edit_row_id} label={edited.predicted_label} p_large={edited.p_large_training:.6f}"
    )
    print(
        "PASS UI payload regression: "
        f"row={target_row_id} tensor={completed_tensor} "
        f"debug_warnings={len(tensor_edit.debug_warnings)}"
    )
    print(
        "PASS zero-duration counterfactual: "
        f"row={target_row_id} label={zero_duration.predicted_label} "
        f"p0={zero_duration.probability_by_label['0']:.6f}"
    )
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)