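# NOTE: the next six lines are web-page residue captured together with this
# file; they are not Python code and are kept only as comments.
# idacy's picture
# Handle no-active-allocation counterfactuals
# ce60faf verified
# Raw
# History Blame Contribute Delete
# 16 kB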
"""Feature normalization and dependent-feature completion for live inference."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
import math
import numpy as np
import pandas as pd
# String tokens that is_missing() treats as "no value". The comparison is done
# on value.strip().lower(), so every entry here is already lower-case.
NULL_STRINGS = {"", "nan", "none", "null", "undefined", "na", "n/a", "<na>"}
# Feature columns that complete_features() runs through clamp_fraction(), which
# limits their numeric value to the closed interval [0.0, 1.0].
FRACTION_COLUMNS = {
    "o1_homogeneous_high_end_fraction",
    "o1_non_partitioned_fraction",
    "o2_concurrency_fraction_domain",
    "o3_training_sku_fraction",
    "o3_billing_continuity_score",
    "o4_gpu_util_duty_gt_70",
    "o4_hbm_used_fraction_p50",
    "o4_hbm_bandwidth_active_p95",
    "o4_gpu_power_fraction_p95",
    "o4_hbm_pressure_duration_fraction",
    "o4_power_cap_active_fraction",
    "o4_thermal_throttle_fraction",
    "o5_kernel_training_motif_score",
    "o5_tensor_throughput_ratio",
    "o6_nvlink_util_p95",
    "o6_nvlink_periodicity_score",
    "o7_scaleout_port_util_p95",
    "o7_collective_periodicity_score",
    "o7_burst_duty_cycle",
    "o7_job_to_port_mapping_coverage",
    "o7_flow_entropy_score",
    "o7_cross_section_sync_score",
    "o7_collective_jitter_score",
    "o7_storage_traffic_fraction",
    "o7_inference_fanout_score",
    "o7_account_flow_linkage_confidence",
    "o8_rack_power_fraction_p95",
    "o8_power_cv",
    "o8_power_cap_or_curtailment_active",
    "o8_unattributed_power_fraction",
    "o9_cooling_flow_duty",
    "o10_rank_stability_score",
    "o10_runtime_metadata_confidence",
    "o11_checkpoint_periodicity_score",
    "o11_read_write_training_pattern_score",
    "o11_checkpoint_jitter_score",
    "o11_artifact_write_pattern_score",
    "o11_dataloader_read_pattern_score",
    "o11_backup_or_replication_pattern_score",
    "o11_storage_cotraffic_score",
    "o12_log_completeness_fraction",
    "o12_declaration_consistency_score",
    "o13_attestation_valid_fraction",
    "o13_confidential_compute_mode_fraction",
    "o13_collector_measurement_valid",
    "o14_min_critical_coverage",
    "o14_gap_fraction_critical",
    "o16_probe_throughput_ratio_min",
    "o17_energy_contract_alignment_score",
    "o17_network_provider_utilization_score",
    "o17_procurement_or_maintenance_explanation_score",
}
# Register the generated names o1_coverage_fraction .. o17_coverage_fraction
# (range(1, 18) is end-exclusive) as fraction columns as well.
for index in range(1, 18):
    FRACTION_COLUMNS.add(f"o{index}_coverage_fraction")
# Columns whose presence among the edited keys makes
# edited_no_active_allocation() evaluate the "no active allocation" condition
# (when not every feature is being derived anyway).
NO_ACTIVE_ALLOCATION_TRIGGER_COLUMNS = {
    "o2_max_concurrent_normalized_gpus",
    "o2_allocation_duration_hours",
}
# Columns that apply_no_active_allocation() sets to the number 0 (in sorted
# key order) when a row is treated as having no active allocation.
NO_ACTIVE_NUMERIC_ZERO_COLUMNS = {
    "policy_compute_ratio",
    "o2_max_concurrent_normalized_gpus",
    "o2_allocation_duration_hours",
    "o2_gpu_hours_policy_ratio",
    "o2_concurrency_fraction_domain",
    "o2_elastic_resize_count",
    "o2_preemption_restart_count",
    "o2_scheduler_queue_delay_hours",
    "o2_job_array_width",
    "o2_reservation_reuse_count",
    "o3_batch_provisioned_gpus",
    "o3_capacity_reservation_duration_hours",
    "o3_training_sku_fraction",
    "o3_billing_continuity_score",
    "o3_egress_tb",
    "o4_gpu_util_p50",
    "o4_gpu_util_p95",
    "o4_gpu_util_duty_gt_70",
    "o4_sm_tensor_active_p95",
    "o4_hbm_used_fraction_p50",
    "o4_hbm_bandwidth_active_p95",
    "o4_gpu_power_fraction_p95",
    "o4_error_spike_score",
    "o4_gpu_util_cv",
    "o4_hbm_pressure_duration_fraction",
    "o4_power_cap_active_fraction",
    "o4_thermal_throttle_fraction",
    "o5_kernel_training_motif_score",
    "o5_tensor_throughput_ratio",
    "o6_nvlink_util_p95",
    "o6_nvlink_periodicity_score",
    "o6_link_error_spike_score",
    "o7_scaleout_port_util_p95",
    "o7_synchronized_fabric_footprint",
    "o7_collective_periodicity_score",
    "o7_burst_duty_cycle",
    "o7_rdma_congestion_score",
    "o7_job_to_port_mapping_coverage",
    "o7_flow_entropy_score",
    "o7_cross_section_sync_score",
    "o7_collective_jitter_score",
    "o7_storage_traffic_fraction",
    "o7_inference_fanout_score",
    "o7_account_flow_linkage_confidence",
    "o8_baseline_subtracted_energy_kwh",
    "o8_power_cv",
    "o8_power_to_gpu_residual",
    "o8_power_baseline_drift_score",
    "o8_unattributed_power_fraction",
    "o9_gpu_hbm_temp_score",
    "o9_thermal_delta_t_score",
    "o9_cooling_flow_duty",
    "o9_thermal_throttle_support_score",
    "o10_world_size",
    "o10_rank_stability_score",
    "o10_same_image_gpu_count",
    "o10_declared_vs_observed_mismatch_score",
    "o11_data_staging_tb",
    "o11_checkpoint_write_tb_per_event",
    "o11_checkpoint_periodicity_score",
    "o11_read_write_training_pattern_score",
    "o11_checkpoint_jitter_score",
    "o11_artifact_write_pattern_score",
    "o11_dataloader_read_pattern_score",
    "o11_backup_or_replication_pattern_score",
    "o11_storage_cotraffic_score",
    "o12_declared_parameter_count_b",
    "o12_training_tokens_b",
    "o12_step_count",
    "o12_log_delivery_delay_hours",
}
# Columns that apply_no_active_allocation() sets to False (in sorted key order)
# when a row is treated as having no active allocation.
NO_ACTIVE_FALSE_COLUMNS = {
    "o2_reservation_exclusive_flag",
    "o5_profiler_available",
    "o8_power_cap_or_curtailment_active",
    "o9_cooling_maintenance_active",
    "o10_rendezvous_present",
    "o12_signed_ml_logs_present",
    "o12_loss_curve_present",
    "o12_optimizer_state_present",
}
# Columns that apply_no_active_allocation() sets to 1 (in sorted key order)
# when a row is treated as having no active allocation.
NO_ACTIVE_FRACTION_ONE_COLUMNS = {
    "o10_runtime_metadata_confidence",
    "o12_log_completeness_fraction",
    "o12_declaration_consistency_score",
}
@dataclass
class FeatureSchema:
    """Per-column type hints for the model's features.

    The hints are read from the dtypes of a feature table; a column with a
    boolean dtype is recorded as both boolean and numeric.
    """

    # Feature names in the order supplied by the caller.
    feature_columns: list[str]
    # Names whose dtype in the source frame is numeric or boolean.
    numeric_columns: set[str] = field(default_factory=set)
    # Names whose dtype in the source frame is boolean.
    boolean_columns: set[str] = field(default_factory=set)

    @classmethod
    def from_frame(cls, feature_columns: list[str], frame: pd.DataFrame | None) -> "FeatureSchema":
        """Build a schema by inspecting ``frame``'s column dtypes.

        Args:
            feature_columns: Feature names to classify.
            frame: Table to read dtypes from, or ``None`` to produce a schema
                with empty type sets.

        Returns:
            A schema holding a copy of ``feature_columns`` plus the numeric and
            boolean subsets. Names absent from ``frame`` are left unclassified.
        """
        if frame is None:
            known: list[str] = []
        else:
            known = [name for name in feature_columns if name in frame.columns]

        booleans = {
            name for name in known if pd.api.types.is_bool_dtype(frame[name].dtype)
        }
        # Boolean columns always count as numeric too.
        numerics = booleans | {
            name for name in known if pd.api.types.is_numeric_dtype(frame[name].dtype)
        }
        return cls(
            feature_columns=list(feature_columns),
            numeric_columns=numerics,
            boolean_columns=booleans,
        )
def is_missing(value: Any) -> bool:
    """Report whether ``value`` should be treated as "no value".

    Strings are missing when, stripped and lower-cased, they appear in
    ``NULL_STRINGS``. Floats (Python or NumPy) are missing when NaN. ``None``
    and ``pd.NA`` are always missing. Anything else is handed to ``pd.isna``;
    if that result cannot be reduced to a single bool the value is treated as
    present.
    """
    if isinstance(value, str):
        token = value.strip().lower()
        return token in NULL_STRINGS
    if isinstance(value, (float, np.floating)):
        return math.isnan(float(value))
    if value is None or value is pd.NA:
        return True
    try:
        verdict = pd.isna(value)
        # bool() raises for multi-element results (lists, arrays).
        return bool(verdict)
    except (TypeError, ValueError):
        return False
def jsonable(value: Any) -> Any:
    """Convert ``value`` to a plain Python equivalent where one is known.

    Missing values (per ``is_missing``) become ``None``; NumPy integers,
    floats and booleans become ``int``/``float``/``bool``; ``pd.Timestamp``
    becomes its ISO-8601 string. Every other value is returned unchanged.
    """
    if is_missing(value):
        return None

    # Checked in this order; the first matching NumPy scalar type wins.
    numpy_casts = (
        (np.integer, int),
        (np.floating, float),
        (np.bool_, bool),
    )
    for numpy_type, cast in numpy_casts:
        if isinstance(value, numpy_type):
            return cast(value)

    if isinstance(value, pd.Timestamp):
        return value.isoformat()
    return value
def normalize_value(value: Any, column: str | None = None, schema: FeatureSchema | None = None) -> Any:
    """Normalize one raw feature value using optional schema hints.

    Args:
        value: Raw value to normalize.
        column: Name of the column the value belongs to, if known.
        schema: Type hints used to decide boolean/numeric coercion.

    Returns:
        ``None`` for missing values; ``True``/``False`` for recognized boolean
        words (in any column) and for ``"1"``/``"0"`` or numbers in boolean
        columns; an ``int`` (when integral) or ``float`` for parseable values
        in numeric columns; the stripped string for other strings; otherwise
        the result of ``jsonable(value)``.
    """
    if is_missing(value):
        return None

    wants_bool = bool(schema and column in schema.boolean_columns)
    wants_number = bool(schema and column in schema.numeric_columns)

    def compact(parsed: float) -> Any:
        # Prefer an int when the float carries no fractional part.
        return int(parsed) if parsed.is_integer() else parsed

    if isinstance(value, str):
        text = value.strip()
        folded = text.lower()
        if folded in {"true", "t", "yes", "y"}:
            return True
        if folded in {"false", "f", "no", "n"}:
            return False
        if wants_bool and folded in {"1", "0"}:
            return folded == "1"
        if not wants_number:
            return text
        try:
            parsed = float(text)
        except ValueError:
            return text
        return compact(parsed)

    # The value is already known not to be missing at this point, so any
    # bool-like or numeric scalar in a boolean column maps through bool().
    bool_like = (bool, np.bool_, int, float, np.integer, np.floating)
    if wants_bool and isinstance(value, bool_like):
        return bool(value)

    if not wants_number:
        return jsonable(value)
    try:
        parsed = float(value)
    except (TypeError, ValueError):
        return value
    if math.isnan(parsed):
        return None
    return compact(parsed)
def normalize_mapping(values: dict[str, Any] | None, schema: FeatureSchema) -> dict[str, Any]:
    """Normalize every entry of ``values`` with ``normalize_value``.

    Keys are converted with ``str()`` and used as the column name for the
    schema lookup. ``None`` or an empty mapping yields an empty dict.
    """
    normalized: dict[str, Any] = {}
    if not values:
        return normalized
    for raw_key, raw_value in values.items():
        column = str(raw_key)
        normalized[column] = normalize_value(raw_value, column, schema)
    return normalized
def number_or_none(value: Any) -> float | None:
    """Return ``value`` as a float, or ``None`` when that is not possible.

    ``None`` is returned for missing values (per ``is_missing``), for values
    ``float()`` rejects with ``TypeError``/``ValueError``, and for NaN.
    """
    if is_missing(value):
        return None
    try:
        parsed = float(value)
    except (TypeError, ValueError):
        return None
    return None if math.isnan(parsed) else parsed
def bool_or_false(value: Any) -> bool:
    """Interpret ``value`` as a boolean, defaulting to ``False``.

    Args:
        value: Any raw value (bool, number, string, missing marker, ...).

    Returns:
        ``True`` for real booleans that are true, for numbers equal to 1
        (``1``, ``1.0``, ``np.float64(1.0)``, ...), and for the
        case-insensitive strings ``true``/``t``/``1``/``yes``/``y``.
        ``False`` for everything else, including missing values.
    """
    if is_missing(value):
        return False
    if isinstance(value, (bool, np.bool_)):
        return bool(value)
    if isinstance(value, (int, float, np.integer, np.floating)):
        # Compare numerically: str(1.0) is "1.0", which the text match below
        # would not recognize, so float-typed flags would read as False.
        return bool(value == 1)
    text = str(value).strip().lower()
    # Anything that is not an explicit "true" token is False, so the false
    # tokens need no separate check.
    return text in {"true", "t", "1", "yes", "y"}
def set_if_changed(row: dict[str, Any], key: str, value: Any, warnings: list[str], derived: list[str]) -> None:
    """Write ``value`` into ``row[key]`` only when it differs from what is there.

    Both the stored and the new value are passed through ``jsonable`` before
    the comparison, and the ``jsonable`` form of ``value`` is what gets stored.
    When a write happens, ``key`` is appended to ``derived``. ``warnings`` is
    accepted but not used by this function.
    """
    current = jsonable(row.get(key))
    replacement = jsonable(value)
    if current != replacement:
        row[key] = replacement
        derived.append(key)
def clamp_fraction(value: Any) -> float | None:
    """Limit ``value`` to the range [0.0, 1.0].

    Returns ``None`` when ``number_or_none`` cannot turn ``value`` into a
    number; otherwise the clamped float.
    """
    number = number_or_none(value)
    if number is None:
        return None
    floored = max(0.0, number)
    return min(1.0, floored)
def edited_no_active_allocation(row: dict[str, Any], edited_feature_keys: set[str], *, should_derive_all: bool) -> bool:
    """Decide whether ``row`` describes a site with no active allocation.

    The check only runs when every feature is being derived or when one of
    ``NO_ACTIVE_ALLOCATION_TRIGGER_COLUMNS`` was edited. It is true when the
    allocated GPU count or the allocation duration is a number that is zero or
    negative.
    """
    check_applies = should_derive_all or bool(
        NO_ACTIVE_ALLOCATION_TRIGGER_COLUMNS & edited_feature_keys
    )
    if not check_applies:
        return False

    for column in ("o2_max_concurrent_normalized_gpus", "o2_allocation_duration_hours"):
        reading = number_or_none(row.get(column))
        if reading is not None and reading <= 0:
            return True
    return False
def apply_no_active_allocation(row: dict[str, Any], warnings: list[str], derived: list[str]) -> None:
    """Overwrite allocation-dependent features in ``row`` for the idle case.

    Every write goes through ``set_if_changed``, so ``row`` is mutated in
    place and changed keys are appended to ``derived``.
    """
    # Column groups and the constant each one receives, in application order.
    group_defaults = (
        (NO_ACTIVE_NUMERIC_ZERO_COLUMNS, 0),
        (NO_ACTIVE_FALSE_COLUMNS, False),
        (NO_ACTIVE_FRACTION_ONE_COLUMNS, 1),
    )
    for columns, fill in group_defaults:
        for key in sorted(columns):
            set_if_changed(row, key, fill, warnings, derived)

    single_defaults = (
        ("o2_declared_workload_class", "none"),
        ("o10_runtime_framework_class", "none"),
        ("o4_gpu_idle_gap_p95_minutes", 60),
    )
    for key, fill in single_defaults:
        set_if_changed(row, key, fill, warnings, derived)

    # Rack power is capped at 0.15 rather than forced to a constant, and only
    # when the row already carries a numeric value for it.
    rack_power = number_or_none(row.get("o8_rack_power_fraction_p95"))
    if rack_power is not None:
        capped = min(rack_power or 0.0, 0.15)
        set_if_changed(row, "o8_rack_power_fraction_p95", capped, warnings, derived)
def complete_features(
    row: dict[str, Any],
    changed_keys: set[str],
    *,
    has_base_row: bool,
    derive: bool,
    edited_feature_keys: set[str] | None = None,
) -> tuple[dict[str, Any], list[str]]:
    """Return a completed row and warnings for coherent live inference.

    Base-row predictions must remain byte-for-byte close to the exported model
    replay, so dependent features are recomputed for base rows only when an
    edited input can affect them. Rows without a base use every available input
    to derive minimal context.

    Args:
        row: Feature values to complete. The mapping is copied, not mutated.
        changed_keys: Keys whose change gates the dependent-feature
            derivations and the fraction clamping for base rows.
        has_base_row: Whether ``row`` started from an existing base row. When
            false, defaults are filled in and every derivation runs.
        derive: When false, only the no-base-row defaults are applied and the
            row is returned without any derivation or consistency warnings.
        edited_feature_keys: Keys used for the no-active-allocation check;
            falls back to ``changed_keys`` when ``None``.

    Returns:
        ``(completed_row, warnings)`` where ``warnings`` holds user-facing
        messages about inconsistent GPU, capacity and fabric sizes.
    """
    out = dict(row)
    warnings: list[str] = []
    derived: list[str] = []

    if not has_base_row:
        # Minimal context for rows built from scratch.
        if is_missing(out.get("scope_type")):
            out["scope_type"] = "site"
        if is_missing(out.get("window_length_seconds")):
            out["window_length_seconds"] = 3600
        if is_missing(out.get("capacity_possible")):
            out["capacity_possible"] = False
    if not derive:
        return out, warnings

    should_derive_all = not has_base_row
    edited_keys = set(changed_keys if edited_feature_keys is None else edited_feature_keys)
    allocation = number_or_none(out.get("o2_max_concurrent_normalized_gpus"))
    duration = number_or_none(out.get("o2_allocation_duration_hours"))
    capacity = number_or_none(out.get("o1_normalized_h100e_capacity"))
    contiguous = number_or_none(out.get("o1_largest_contiguous_domain_gpus"))

    if edited_no_active_allocation(out, edited_keys, should_derive_all=should_derive_all):
        apply_no_active_allocation(out, warnings, derived)
        # The idle defaults overwrite both inputs, so read them again.
        allocation = number_or_none(out.get("o2_max_concurrent_normalized_gpus"))
        duration = number_or_none(out.get("o2_allocation_duration_hours"))

    ratio_inputs_changed = bool(
        {"o2_max_concurrent_normalized_gpus", "o2_allocation_duration_hours"} & changed_keys
    )
    if (should_derive_all or ratio_inputs_changed) and allocation is not None and duration is not None:
        # GPU-hours relative to a 512-GPU, 24-hour reference.
        ratio = (allocation * duration) / (512 * 24)
        set_if_changed(out, "o2_gpu_hours_policy_ratio", ratio, warnings, derived)
        set_if_changed(out, "policy_compute_ratio", ratio, warnings, derived)

    concurrency_inputs_changed = bool(
        {"o2_max_concurrent_normalized_gpus", "o1_normalized_h100e_capacity"} & changed_keys
    )
    if (should_derive_all or concurrency_inputs_changed) and allocation is not None and capacity is not None:
        concurrency = clamp_fraction(allocation / capacity) if capacity > 0 else 0.0
        set_if_changed(out, "o2_concurrency_fraction_domain", concurrency, warnings, derived)

    if (
        (should_derive_all or "o2_max_concurrent_normalized_gpus" in changed_keys)
        and allocation is not None
        # number_or_none only filters NaN; round() of +/-inf raises
        # OverflowError, so a non-finite allocation leaves these untouched.
        and math.isfinite(allocation)
    ):
        rounded_allocation = int(round(allocation))
        set_if_changed(out, "o10_world_size", rounded_allocation, warnings, derived)
        set_if_changed(out, "o10_same_image_gpu_count", rounded_allocation, warnings, derived)

    capacity_inputs_changed = bool(
        {"o1_normalized_h100e_capacity", "o1_largest_contiguous_domain_gpus"} & changed_keys
    )
    if should_derive_all or capacity_inputs_changed:
        if capacity is not None and contiguous is not None:
            set_if_changed(out, "capacity_possible", capacity >= 512 and contiguous >= 512, warnings, derived)
        elif not has_base_row:
            set_if_changed(out, "capacity_possible", False, warnings, derived)

    for key in sorted(FRACTION_COLUMNS & set(out)):
        if should_derive_all or key in changed_keys:
            clamped = clamp_fraction(out.get(key))
            if clamped is not None:
                set_if_changed(out, key, clamped, warnings, derived)

    # Re-read after derivation so the warnings describe the completed row.
    allocation = number_or_none(out.get("o2_max_concurrent_normalized_gpus"))
    capacity = number_or_none(out.get("o1_normalized_h100e_capacity"))
    fabric = number_or_none(out.get("o7_synchronized_fabric_footprint"))
    if capacity is not None and allocation is not None and allocation > capacity + 1:
        warnings.append(
            "The allocated GPU count is higher than this site's monitored GPU capacity. "
            "Lower allocated GPUs or choose a higher-capacity site."
        )
    if capacity is not None and fabric is not None and fabric > capacity + 1:
        warnings.append(
            "The network fabric size is higher than this site's monitored GPU capacity. "
            "Lower fabric footprint or choose a higher-capacity site."
        )
    if allocation is not None and allocation > 0 and fabric is not None and fabric > allocation + 1:
        warnings.append(
            "The network fabric size is higher than the allocated GPU count. "
            "Lower fabric footprint or raise allocated GPUs."
        )
    return out, warnings