idacy's picture
Fix live inference slider regression
c789799 verified
Raw
History Blame Contribute Delete
2.56 kB
"""Pydantic schemas for the live datacenter verification inference API."""
from __future__ import annotations
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
class PredictRequest(BaseModel):
    """Request body for scoring a single feature row.

    ``features`` is a free-form ``dict[str, Any]``, so callers may send any
    subset of the model's columns (the public demo only exposes some of them)
    as well as keys that newer UI controls introduce before this schema is
    updated. ``extra="allow"`` additionally keeps unknown *top-level* fields
    on the instance instead of rejecting them.
    """

    # Retain unrecognised top-level keys rather than raising a validation error.
    model_config = ConfigDict(extra="allow")

    # Optional caller-supplied row identifier (PredictResponse has a field of
    # the same name).
    feature_row_id: str | None = Field(default=None)
    # Feature values keyed by column name; each instance gets its own dict.
    features: dict[str, Any] = Field(default_factory=dict)
    # Auxiliary request context; each instance gets its own dict.
    context: dict[str, Any] = Field(default_factory=dict)
    # Presumably toggles derived-feature computation in the inference code —
    # confirm there.
    derive: bool = Field(default=True)
    # Presumably asks for PredictResponse.completed_features to be filled in.
    return_completed_features: bool = Field(default=False)
class BatchPredictRequest(BaseModel):
    """Request body that bundles several single-row prediction requests."""

    # Required; every element is validated as a full PredictRequest.
    requests: list[PredictRequest] = Field(...)
class HealthResponse(BaseModel):
    """Health-check payload describing the service and its loaded model.

    Only ``status``, ``model_loaded`` and ``api_version`` are required. The
    other fields default to ``None``, ``0`` or ``False`` when not supplied.
    """

    # Pydantic 2.0-2.9 reserve the ``model_`` prefix and emit a UserWarning at
    # class-definition time for fields such as ``model_loaded`` and
    # ``model_run_id``. Clearing the protected namespaces silences that
    # warning. It does not change validation or serialization.
    model_config = ConfigDict(protected_namespaces=())

    status: str
    model_loaded: bool
    api_version: str
    build_sha: str | None = None
    build_source: str | None = None
    model_run_id: str | None = None
    dataset_id: str | None = None
    feature_count: int = 0
    base_row_lookup_enabled: bool = False
    # Presumably populated when the service is unhealthy — confirm in the handler.
    error: str | None = None
class MetadataResponse(BaseModel):
    """Describes the model run, dataset and feature layout served by the API."""

    # Pydantic 2.0-2.9 warn at class-definition time about fields starting
    # with ``model_`` (``model_run_id``, ``model_run_dir``, ``model_type``).
    # Clearing the protected namespaces silences that warning. It does not
    # change validation or serialization.
    model_config = ConfigDict(protected_namespaces=())

    api_version: str
    build_sha: str | None = None
    build_source: str | None = None
    model_run_id: str
    model_run_dir: str
    feature_table: str | None = None
    dataset_id: str | None = None
    dataset_scale: str | None = None
    model_type: str | None = None
    # Free-form metrics mapping; each instance gets its own dict.
    metrics_summary: dict[str, Any] = Field(default_factory=dict)
    feature_count: int
    feature_columns: list[str]
    supported_labels: list[int]
    base_row_lookup_enabled: bool
class PredictResponse(BaseModel):
    """Result of scoring one feature row with the live model."""

    # Pydantic 2.0-2.9 warn at class-definition time about the ``model_run_id``
    # field because of the reserved ``model_`` prefix. Clearing the protected
    # namespaces silences that warning. It does not change validation or
    # serialization.
    model_config = ConfigDict(protected_namespaces=())

    # Response-mode tag; defaults to "live_model_inference".
    mode: str = "live_model_inference"
    model_run_id: str
    feature_row_id: str | None = None
    predicted_label: int
    p_large_training: float
    severity_score: float
    negative_certification_confidence: float
    integrity_warning: bool
    capacity_possible: bool
    min_critical_coverage: float
    # Presumably one probability per supported label, in label order — confirm.
    probabilities: list[float]
    # Label-keyed probabilities; keys are strings (JSON object keys).
    probability_by_label: dict[str, float]
    # Optional label-keyed probabilities; empty dict per instance by default.
    raw_probability_by_label: dict[str, float] = Field(default_factory=dict)
    top_evidence: list[str]
    critical_missing_layers: list[str]
    input_warnings: list[str] = Field(default_factory=list)
    debug_warnings: list[str] = Field(default_factory=list)
    # Presumably filled only when the request set return_completed_features.
    completed_features: dict[str, Any] = Field(default_factory=dict)
class BatchPredictResponse(BaseModel):
    """Results for a batch of single-row predictions from one model run."""

    # Pydantic 2.0-2.9 warn at class-definition time about the ``model_run_id``
    # field because of the reserved ``model_`` prefix. Clearing the protected
    # namespaces silences that warning. It does not change validation or
    # serialization.
    model_config = ConfigDict(protected_namespaces=())

    # Response-mode tag; defaults to "live_model_inference_batch".
    mode: str = "live_model_inference_batch"
    model_run_id: str
    # One PredictResponse per scored row.
    predictions: list[PredictResponse]