idacy's picture
Handle no-active-allocation counterfactuals
ce60faf verified
Raw
History Blame Contribute Delete
13.1 kB
"""Sklearn model loading and one-row inference for the live API."""
from __future__ import annotations
import os
from pathlib import Path
from typing import Any
import joblib
import numpy as np
import pandas as pd
from src.datacenter_verification_api import __version__, build_info
from src.datacenter_verification_api.feature_completion import (
FeatureSchema,
complete_features,
jsonable,
normalize_mapping,
number_or_none,
)
from src.datacenter_verification_api.schemas import MetadataResponse, PredictRequest, PredictResponse
from src.datacenter_verification_modeling.common import (
LABELS,
PROB_COLUMNS,
add_governance_outputs,
model_input_frame,
probability_frame,
read_json,
)
# Third ancestor of this file's resolved path (parents[0] is the containing
# directory). Relative artifact paths are joined onto this by resolve_repo_path.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Fallback model-run directory used by ModelService.from_env when the
# DCV_MODEL_RUN_DIR environment variable is unset or empty.
DEFAULT_MODEL_RUN_DIR = Path("data/model_runs/synthetic_v1_baseline")
# Fallback feature-table CSV used by ModelService.from_env when the
# DCV_FEATURE_TABLE environment variable is unset or empty.
DEFAULT_FEATURE_TABLE = Path("data/synthetic_v1/features/window_features_all.csv")
# Request keys that build_feature_row accepts without an "unrecognized input"
# warning; it records a debug warning for them instead.
KNOWN_NON_MODEL_METADATA_FIELDS = {
"capacity_evidence_only",
"integrity_evidence_only",
"physical_evidence_only",
}
# Fixed per-label probabilities that predict() substitutes for the model output
# when the no-active-allocation counterfactual applies. Keys are looked up with
# the entries of LABELS; the five values sum to 1.0.
NO_ACTIVE_ALLOCATION_PROBABILITIES = {
0: 0.97,
1: 0.02,
2: 0.006,
3: 0.002,
4: 0.002,
}
# The counterfactual check only runs when the request's features include at
# least one of these keys; the same two keys are then read from the completed row.
NO_ACTIVE_ALLOCATION_KEYS = {
"o2_max_concurrent_normalized_gpus",
"o2_allocation_duration_hours",
}
def resolve_repo_path(value: str | Path | None, default: Path) -> Path:
    """Turn a configured path (or its default) into a usable location.

    A falsy ``value`` (``None`` or an empty string) selects ``default``.
    Absolute paths are returned untouched; relative ones are anchored at
    ``REPO_ROOT`` and resolved.
    """
    candidate = default if not value else Path(value)
    return candidate if candidate.is_absolute() else (REPO_ROOT / candidate).resolve()
def repo_relative(path: Path | None) -> str | None:
    """Express ``path`` relative to ``REPO_ROOT`` as a POSIX-style string.

    ``None`` passes straight through. If the resolved path cannot be made
    relative to the repository root, only the final path component is returned.
    """
    if path is None:
        return None
    try:
        inside_repo = path.resolve().relative_to(REPO_ROOT)
    except ValueError:
        # relative_to raises ValueError for locations outside the repo root.
        return path.name
    else:
        return inside_repo.as_posix()
def read_json_if_exists(path: Path) -> Any:
    """Return the parsed JSON at ``path``, or an empty dict when the path is absent."""
    if not path.exists():
        return {}
    return read_json(path)
def split_semicolon(value: Any) -> list[str]:
    """Split a ``;``-delimited value into trimmed, non-empty pieces.

    ``None`` and pandas-missing scalars yield an empty list. Any other value
    is stringified before splitting.
    """
    if value is None:
        return []
    try:
        missing = bool(pd.isna(value))
    except (TypeError, ValueError):
        # Non-scalar inputs make the truth test ambiguous; treat them as present.
        missing = False
    if missing:
        return []
    trimmed = (piece.strip() for piece in str(value).split(";"))
    return [piece for piece in trimmed if piece]
class ModelService:
    """Loaded model artifacts plus base-row lookup for prediction requests.

    An instance holds the manifest, metrics, feature-column list, fitted
    preprocessor and model of one model run, together with an optional
    feature table indexed by ``feature_row_id`` that supplies base rows for
    requests which reference a known datapoint.
    """

    def __init__(self, model_run_dir: Path, feature_table_path: Path | None = None) -> None:
        """Load all artifacts of a model run and index the feature table.

        Args:
            model_run_dir: Directory containing ``feature_columns.json``,
                ``preprocessing.joblib`` and ``model.joblib``. ``manifest.json``
                and ``metrics.json`` are read when present and default to ``{}``.
            feature_table_path: CSV of reference rows. When ``None`` or missing
                on disk, base-row lookup is disabled.

        Raises:
            ValueError: If the feature table exists but has no
                ``feature_row_id`` column.
        """
        self.model_run_dir = model_run_dir
        self.feature_table_path = feature_table_path
        self.manifest = read_json_if_exists(model_run_dir / "manifest.json")
        self.metrics = read_json_if_exists(model_run_dir / "metrics.json")
        self.feature_columns: list[str] = read_json(model_run_dir / "feature_columns.json")
        # SECURITY NOTE: joblib.load unpickles arbitrary objects. The directory
        # can come from the DCV_MODEL_RUN_DIR environment variable (see
        # from_env), so it must only ever point at trusted model runs.
        self.preprocessor = joblib.load(model_run_dir / "preprocessing.joblib")
        self.model = joblib.load(model_run_dir / "model.joblib")
        self.feature_table = self._load_feature_table(feature_table_path)
        self.feature_schema = FeatureSchema.from_frame(self.feature_columns, self.feature_table)
        self.feature_lookup = self._build_feature_lookup(self.feature_table)
        # Fall back to the directory name when the manifest carries no run id.
        self.model_run_id = str(self.manifest.get("model_run_id") or model_run_dir.name)

    @classmethod
    def from_env(cls) -> "ModelService":
        """Build a service from ``DCV_MODEL_RUN_DIR`` and ``DCV_FEATURE_TABLE``.

        Unset or empty variables fall back to the module defaults; relative
        paths are resolved against the repository root.
        """
        model_run_dir = resolve_repo_path(os.getenv("DCV_MODEL_RUN_DIR"), DEFAULT_MODEL_RUN_DIR)
        feature_table = resolve_repo_path(os.getenv("DCV_FEATURE_TABLE"), DEFAULT_FEATURE_TABLE)
        return cls(model_run_dir=model_run_dir, feature_table_path=feature_table)

    def _load_feature_table(self, path: Path | None) -> pd.DataFrame | None:
        """Read the feature-table CSV, or return ``None`` when it is unavailable.

        Raises:
            ValueError: If the CSV has no ``feature_row_id`` column.
        """
        if path is None or not path.exists():
            return None
        frame = pd.read_csv(path)
        if "feature_row_id" not in frame.columns:
            raise ValueError(f"feature table lacks feature_row_id column: {path}")
        return frame

    def _build_feature_lookup(self, frame: pd.DataFrame | None) -> dict[str, dict[str, Any]]:
        """Index feature-table rows by the string form of ``feature_row_id``.

        Rows whose id is missing are skipped. Blank CSV cells load as NaN, and
        indexing them would register a bogus row under the literal key
        ``"nan"``. When an id occurs more than once, the last row wins.
        """
        if frame is None or "feature_row_id" not in frame.columns:
            return {}
        identified = frame.loc[frame["feature_row_id"].notna()]
        return {
            str(record["feature_row_id"]): record
            for record in identified.to_dict(orient="records")
        }

    def _feature_table_dataset_id(self) -> str | None:
        """Return the first non-null ``dataset_id`` in the feature table, if any."""
        table = self.feature_table
        if table is None or "dataset_id" not in table.columns:
            return None
        values = table["dataset_id"].dropna().unique()
        return str(values[0]) if len(values) else None

    @property
    def dataset_id(self) -> str | None:
        """Dataset identifier inferred from the manifest or the feature table.

        A manifest ``features_path`` (preferring the one nested under
        ``training_metadata``) containing ``"synthetic_v1"`` maps to that id;
        otherwise the feature table's ``dataset_id`` column is consulted.
        """
        training_metadata = self.manifest.get("training_metadata", {})
        path = None
        if isinstance(training_metadata, dict):
            path = training_metadata.get("features_path")
        path = path or self.manifest.get("features_path")
        if isinstance(path, str) and "synthetic_v1" in path:
            return "synthetic_v1"
        return self._feature_table_dataset_id()

    @property
    def dataset_scale(self) -> str | None:
        """First non-null ``dataset_id`` value of the feature table, if any.

        NOTE(review): this reads the ``dataset_id`` column, not a dedicated
        scale column. Behaviour is preserved as-is; confirm this is intended.
        """
        return self._feature_table_dataset_id()

    @property
    def model_type(self) -> str | None:
        """Model type from the manifest, else the loaded model's class name.

        The top-level ``model_type`` wins over ``training_metadata.model_type``;
        non-string values are ignored.
        """
        declared = self.manifest.get("model_type")
        if isinstance(declared, str):
            return declared
        training_metadata = self.manifest.get("training_metadata", {})
        if isinstance(training_metadata, dict):
            nested = training_metadata.get("model_type")
            if isinstance(nested, str):
                return nested
        return type(self.model).__name__

    def metadata(self) -> MetadataResponse:
        """Describe the loaded model run, build and feature table.

        ``metrics_summary`` comes from the manifest when it is a dict there;
        otherwise it is assembled from the ``model``, ``governance`` and
        ``calibration`` sections of the metrics file.
        """
        metrics_summary = self.manifest.get("metrics_summary")
        if not isinstance(metrics_summary, dict):
            metrics_summary = {
                "model": self.metrics.get("model", {}),
                "governance": self.metrics.get("governance", {}),
                "calibration": self.metrics.get("calibration", {}),
            }
        build = build_info()
        return MetadataResponse(
            api_version=__version__,
            build_sha=build.sha,
            build_source=build.source,
            model_run_id=self.model_run_id,
            model_run_dir=repo_relative(self.model_run_dir) or self.model_run_dir.name,
            feature_table=repo_relative(self.feature_table_path),
            dataset_id=self.dataset_id,
            dataset_scale=self.dataset_scale,
            model_type=self.model_type,
            metrics_summary=metrics_summary,
            feature_count=len(self.feature_columns),
            feature_columns=self.feature_columns,
            supported_labels=LABELS,
            base_row_lookup_enabled=bool(self.feature_lookup),
        )

    def build_feature_row(self, request: PredictRequest) -> tuple[dict[str, Any], list[str], list[str]]:
        """Assemble the full feature row for a request.

        The row starts from the referenced base row (or all-``None`` model
        columns when none is found), is overlaid with the request's context
        and then its features, and is finally passed through
        ``complete_features``.

        Returns:
            ``(row, warnings, debug_warnings)``: the completed row, the
            user-facing input warnings and the diagnostic-only warnings.
        """
        warnings: list[str] = []
        debug_warnings: list[str] = []

        base_row = None
        if request.feature_row_id:
            base_row = self.feature_lookup.get(request.feature_row_id)
            if base_row is None:
                warnings.append(
                    "The selected datapoint was not found in the live API reference table. "
                    "Live scoring may be less reliable because many model inputs may be missing."
                )

        has_base_row = base_row is not None
        row: dict[str, Any]
        if has_base_row:
            # Copy so request overrides never mutate the shared lookup entry.
            row = dict(base_row)
        else:
            row = {column: None for column in self.feature_columns}
        if request.feature_row_id:
            row["feature_row_id"] = request.feature_row_id

        context = normalize_mapping(request.context, self.feature_schema)
        features = normalize_mapping(request.features, self.feature_schema)
        changed_keys: set[str] = set()
        feature_keys: set[str] = set(features)
        # Built once so the per-key membership test below is O(1).
        model_columns = set(self.feature_columns)

        # Features are applied after context, so they win on conflicting keys.
        for source_name, values in (("context", context), ("features", features)):
            for key, value in values.items():
                row[key] = value
                changed_keys.add(key)
                if key in KNOWN_NON_MODEL_METADATA_FIELDS:
                    debug_warnings.append(
                        f"{source_name} field is metadata-only and was not sent to the model: {key}"
                    )
                elif key != "feature_row_id" and key not in model_columns:
                    warnings.append(f"The live API ignored an unrecognized input field: {key}")

        # A base row may lack some model columns; make every one present.
        for column in self.feature_columns:
            row.setdefault(column, None)

        row, completion_warnings = complete_features(
            row,
            changed_keys,
            has_base_row=has_base_row,
            derive=request.derive,
            edited_feature_keys=feature_keys,
        )
        warnings.extend(completion_warnings)

        if not has_base_row:
            null_count = sum(1 for column in self.feature_columns if row.get(column) is None)
            # Warn once more than a third of the model inputs are still unknown.
            if null_count > len(self.feature_columns) // 3:
                warnings.append(
                    f"The live API is missing {null_count} of {len(self.feature_columns)} model inputs. "
                    "Choose a sampled datapoint before editing controls for a more reliable live score."
                )
        return row, warnings, debug_warnings

    def no_active_allocation_counterfactual(self, request: PredictRequest, row: dict[str, Any]) -> bool:
        """Tell whether the request describes a "no active allocation" scenario.

        True only when derivation is enabled, the request's features include
        at least one of ``NO_ACTIVE_ALLOCATION_KEYS``, and the completed row
        has both the GPU allocation and the allocation duration at or below
        zero.
        """
        if not request.derive:
            return False
        # ``or ()`` keeps a request without features from raising TypeError.
        edited = set(request.features or ())
        if not (NO_ACTIVE_ALLOCATION_KEYS & edited):
            return False
        allocation = number_or_none(row.get("o2_max_concurrent_normalized_gpus"))
        duration = number_or_none(row.get("o2_allocation_duration_hours"))
        return allocation is not None and allocation <= 0 and duration is not None and duration <= 0

    def predict(self, request: PredictRequest) -> PredictResponse:
        """Score one request and return probabilities plus governance outputs.

        The completed feature row is preprocessed and scored by the model, and
        ``add_governance_outputs`` supplies the post-processed fields. When
        the no-active-allocation counterfactual applies, the probabilities and
        the values derived from them are replaced by the fixed
        ``NO_ACTIVE_ALLOCATION_PROBABILITIES`` table. ``raw_probability_by_label``
        always reports the model's own output.
        """
        row, warnings, debug_warnings = self.build_feature_row(request)

        # Single-row frame over every known key; missing values become NaN.
        record = {}
        for column in sorted(set(row) | set(self.feature_columns)):
            value = row.get(column)
            record[column] = np.nan if value is None else value
        frame = pd.DataFrame([record])

        model_frame = model_input_frame(frame, self.feature_columns)
        transformed = self.preprocessor.transform(model_frame)
        raw_probabilities = probability_frame(self.model, transformed)
        governance = add_governance_outputs(frame, raw_probabilities)
        raw = raw_probabilities.iloc[0]
        post = governance.iloc[0]

        probabilities = [float(post[f"p_label_{label}"]) for label in LABELS]
        predicted_label = int(post["predicted_label"])
        p_large_training = float(post["p_large_training"])
        severity_score = float(post["severity_score"])
        negative_certification_confidence = float(post["negative_certification_confidence"])
        integrity_warning = bool(post["integrity_warning"])
        top_evidence = split_semicolon(post["top_evidence"])

        if self.no_active_allocation_counterfactual(request, row):
            override = NO_ACTIVE_ALLOCATION_PROBABILITIES
            probabilities = [float(override[label]) for label in LABELS]
            predicted_label = 0
            p_large_training = float(override[3] + override[4])
            # Probability-weighted mean label under the override distribution.
            severity_score = float(sum(label * override[label] for label in LABELS))
            negative_certification_confidence = float(
                override[0] * float(post["min_critical_coverage"])
            )
            integrity_warning = False
            top_evidence = ["no active allocation", "no strong positive evidence"]
            debug_warnings.append("applied no-active-allocation counterfactual override")

        # Check the flag once instead of re-testing it for every column.
        if request.return_completed_features:
            completed_features = {column: jsonable(row.get(column)) for column in self.feature_columns}
        else:
            completed_features = {}

        return PredictResponse(
            model_run_id=self.model_run_id,
            feature_row_id=request.feature_row_id or jsonable(row.get("feature_row_id")),
            predicted_label=predicted_label,
            p_large_training=p_large_training,
            severity_score=severity_score,
            negative_certification_confidence=negative_certification_confidence,
            integrity_warning=integrity_warning,
            capacity_possible=bool(post["capacity_possible"]),
            min_critical_coverage=float(post["min_critical_coverage"]),
            probabilities=probabilities,
            probability_by_label={str(label): float(probabilities[index]) for index, label in enumerate(LABELS)},
            raw_probability_by_label={str(label): float(raw[f"p_label_{label}"]) for label in LABELS},
            top_evidence=top_evidence,
            critical_missing_layers=split_semicolon(post["critical_missing_layers"]),
            input_warnings=warnings,
            debug_warnings=debug_warnings,
            completed_features=completed_features,
        )