Update hf_demo.py
Browse files- hf_demo.py +143 -11
hf_demo.py
CHANGED
|
@@ -36,6 +36,9 @@ from infrastructure import (
|
|
| 36 |
RecommendedAction,
|
| 37 |
)
|
| 38 |
|
|
|
|
|
|
|
|
|
|
| 39 |
# ============== CONFIGURATION (Pydantic V2) ==============
|
| 40 |
class Settings(BaseSettings):
|
| 41 |
"""Application settings loaded from environment variables."""
|
|
@@ -257,6 +260,44 @@ class BayesianRiskEngine:
|
|
| 257 |
except sqlite3.Error as e:
|
| 258 |
logger.error(f"Failed to record outcome: {e}")
|
| 259 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
class PolicyEngine:
|
| 261 |
"""Deterministic OSS policies – advisory only."""
|
| 262 |
def __init__(self):
|
|
@@ -399,6 +440,18 @@ class RAGMemory:
|
|
| 399 |
self.embedding_cache[text] = embedding
|
| 400 |
return embedding
|
| 401 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 402 |
def _init_db(self):
|
| 403 |
try:
|
| 404 |
with self._get_db() as conn:
|
|
@@ -416,6 +469,15 @@ class RAGMemory:
|
|
| 416 |
embedding TEXT
|
| 417 |
)
|
| 418 |
''')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
conn.execute('''
|
| 420 |
CREATE TABLE IF NOT EXISTS signals (
|
| 421 |
id TEXT PRIMARY KEY,
|
|
@@ -449,17 +511,19 @@ class RAGMemory:
|
|
| 449 |
conn.close()
|
| 450 |
|
| 451 |
def store_incident(self, action: str, risk_score: float, risk_level: RiskLevel,
|
| 452 |
-
confidence: float, allowed: bool, gates: List[Dict]
|
|
|
|
|
|
|
| 453 |
action_hash = hashlib.sha256(action.encode()).hexdigest()[:50]
|
| 454 |
-
# Build a descriptive text and generate embedding
|
| 455 |
incident_text = self._build_incident_text(action)
|
| 456 |
embedding = json.dumps(self._simple_embedding(incident_text))
|
| 457 |
try:
|
| 458 |
with self._get_db() as conn:
|
| 459 |
conn.execute('''
|
| 460 |
INSERT INTO incidents
|
| 461 |
-
(id, action, action_hash, risk_score, risk_level, confidence, allowed, gates, timestamp, embedding
|
| 462 |
-
|
|
|
|
| 463 |
''', (
|
| 464 |
str(uuid.uuid4()),
|
| 465 |
action[:500],
|
|
@@ -470,7 +534,13 @@ class RAGMemory:
|
|
| 470 |
1 if allowed else 0,
|
| 471 |
json.dumps(gates),
|
| 472 |
datetime.utcnow().isoformat(),
|
| 473 |
-
embedding
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
))
|
| 475 |
conn.commit()
|
| 476 |
except sqlite3.Error as e:
|
|
@@ -663,6 +733,23 @@ class InfrastructureEvaluationResponse(BaseModel):
|
|
| 663 |
confidence_score: float
|
| 664 |
evaluation_details: Dict[str, Any]
|
| 665 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 666 |
# ============== FASTAPI APP ==============
|
| 667 |
app = FastAPI(
|
| 668 |
title="ARF OSS Real Engine (API Only)",
|
|
@@ -686,6 +773,7 @@ app.add_middleware(
|
|
| 686 |
risk_engine = BayesianRiskEngine()
|
| 687 |
policy_engine = PolicyEngine()
|
| 688 |
memory = RAGMemory()
|
|
|
|
| 689 |
|
| 690 |
# ============== INFRASTRUCTURE SIMULATOR INSTANCE ==============
|
| 691 |
# Corrected: RegionAllowedPolicy expects 'allowed_regions', not 'regions'
|
|
@@ -741,12 +829,28 @@ async def evaluate_action(request: ActionRequest):
|
|
| 741 |
"environment": "production",
|
| 742 |
"user_role": request.user_role,
|
| 743 |
"backup_available": request.rollbackFeasible,
|
| 744 |
-
"requires_human": request.requiresHuman
|
|
|
|
| 745 |
}
|
| 746 |
-
risk
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 750 |
policy = policy_engine.evaluate(
|
| 751 |
action=request.proposedAction,
|
| 752 |
risk=risk,
|
|
@@ -754,6 +858,14 @@ async def evaluate_action(request: ActionRequest):
|
|
| 754 |
)
|
| 755 |
similar = memory.find_similar(request.proposedAction, limit=3)
|
| 756 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 757 |
if not policy["allowed"] and risk["score"] > 0.7:
|
| 758 |
memory.track_enterprise_signal(
|
| 759 |
signal_type=LeadSignal.HIGH_RISK_BLOCKED,
|
|
@@ -778,7 +890,13 @@ async def evaluate_action(request: ActionRequest):
|
|
| 778 |
risk_level=risk["level"],
|
| 779 |
confidence=request.confidenceScore,
|
| 780 |
allowed=policy["allowed"],
|
| 781 |
-
gates=policy["gates"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 782 |
)
|
| 783 |
gates = []
|
| 784 |
for g in policy["gates"]:
|
|
@@ -904,6 +1022,20 @@ async def evaluate_infrastructure_intent(request: InfrastructureIntentRequest):
|
|
| 904 |
logger.error(f"Infrastructure evaluation failed: {e}", exc_info=True)
|
| 905 |
raise HTTPException(500, detail=str(e))
|
| 906 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 907 |
# ============== MAIN ENTRY POINT ==============
|
| 908 |
if __name__ == "__main__":
|
| 909 |
import uvicorn
|
|
|
|
| 36 |
RecommendedAction,
|
| 37 |
)
|
| 38 |
|
| 39 |
+
# ============== HMC LEARNER IMPORT ==============
|
| 40 |
+
from hmc_learner import train_hmc_model # new import
|
| 41 |
+
|
| 42 |
# ============== CONFIGURATION (Pydantic V2) ==============
|
| 43 |
class Settings(BaseSettings):
|
| 44 |
"""Application settings loaded from environment variables."""
|
|
|
|
| 260 |
except sqlite3.Error as e:
|
| 261 |
logger.error(f"Failed to record outcome: {e}")
|
| 262 |
|
| 263 |
+
# ---------- NEW: Enhanced risk using HMC coefficients ----------
|
| 264 |
+
def enhanced_risk(self, action_text: str, context: Dict, hmc_coeffs: Optional[Dict] = None) -> float:
    """
    Compute a risk score using HMC coefficients if available.

    Falls back to the simple posterior score when no coefficients are given.

    Args:
        action_text: Free-text description of the proposed action.
        context: Request context; this method reads the ``environment``,
            ``user_role`` and ``confidence`` keys.
        hmc_coeffs: Model data produced by HMC training. Coefficient
            summaries (``{"mean": ...}``) are looked up by name either at
            the top level or under a ``"coefficients"`` key.

    Returns:
        A probability in [0, 1] as a plain Python float when coefficients
        are supplied; otherwise whatever ``calculate_posterior`` reports
        under ``"score"``.
    """
    # The simple posterior is both the fallback result and a model feature.
    simple_risk = self.calculate_posterior(action_text, context)["score"]
    if not hmc_coeffs:
        return simple_risk

    # NOTE(review): the training endpoint reads model_data.get("coefficients"),
    # so the summaries are presumably nested under that key - confirm against
    # hmc_learner. Both layouts are accepted here.
    coeffs = hmc_coeffs.get("coefficients") or hmc_coeffs

    def coeff_mean(name, default=0.0):
        """Return the posterior mean stored for *name*, or *default*."""
        entry = coeffs.get(name)
        if isinstance(entry, dict):
            return entry.get("mean", default)
        return default

    # Map the action category to its integer code. The saved mapping is
    # code -> category; after a JSON round trip the codes are strings, so
    # they must be converted back to int before indexing a list.
    action_cat = self.classify_action(action_text)
    cat_mapping = hmc_coeffs.get("action_cat_mapping") or coeffs.get("action_cat_mapping") or {}
    cat_to_code = {category: code for code, category in cat_mapping.items()}
    try:
        cat_code = int(cat_to_code.get(action_cat, 0))  # default to 0 if not found
    except (TypeError, ValueError):
        cat_code = 0

    # Per-category intercept; unknown or out-of-range codes contribute 0.
    alpha_means = coeff_mean('α_cat', [0.0])
    if isinstance(alpha_means, (list, tuple)):
        alpha = alpha_means[cat_code] if 0 <= cat_code < len(alpha_means) else 0.0
    else:
        alpha = alpha_means

    env_prod = 1 if context.get('environment') == 'production' else 0
    role_junior = 1 if context.get('user_role') == 'junior' else 0
    hour = datetime.now().hour
    confidence = context.get('confidence', 0.85)

    # Linear predictor from HMC coefficients (features centred as in the original code).
    logit = (
        alpha +
        coeff_mean('β_env') * env_prod +
        coeff_mean('β_role') * role_junior +
        coeff_mean('β_risk') * (simple_risk - 0.5) +
        coeff_mean('β_hour') * ((hour - 12) / 12) +
        coeff_mean('β_conf') * (confidence - 0.5)
    )

    # Clamp before exponentiating so extreme logits cannot overflow np.exp,
    # then convert to a probability and return a built-in float.
    logit = float(np.clip(logit, -60.0, 60.0))
    return float(1.0 / (1.0 + np.exp(-logit)))
|
| 299 |
+
|
| 300 |
+
|
| 301 |
class PolicyEngine:
|
| 302 |
"""Deterministic OSS policies – advisory only."""
|
| 303 |
def __init__(self):
|
|
|
|
| 440 |
self.embedding_cache[text] = embedding
|
| 441 |
return embedding
|
| 442 |
|
| 443 |
+
def _ensure_columns(self, conn, columns):
    """
    Add any of the given columns that the ``incidents`` table lacks.

    Args:
        conn: Open SQLite connection used for the schema query and ALTERs.
        columns: Iterable of ``(column_name, column_type)`` pairs.

    A failed ALTER is logged and does not stop the remaining columns from
    being processed.
    """
    # Row index 1 of each PRAGMA table_info row is read as the column name.
    table_info = conn.execute("PRAGMA table_info(incidents)").fetchall()
    present = {info[1] for info in table_info}

    missing = ((name, sql_type) for name, sql_type in columns if name not in present)
    for col_name, col_type in missing:
        try:
            conn.execute(f"ALTER TABLE incidents ADD COLUMN {col_name} {col_type}")
        except sqlite3.Error as e:
            logger.error(f"Failed to add column {col_name}: {e}")
        else:
            logger.info(f"Added column {col_name} to incidents table")
|
| 454 |
+
|
| 455 |
def _init_db(self):
|
| 456 |
try:
|
| 457 |
with self._get_db() as conn:
|
|
|
|
| 469 |
embedding TEXT
|
| 470 |
)
|
| 471 |
''')
|
| 472 |
+
# Add new columns if they don't exist
|
| 473 |
+
self._ensure_columns(conn, [
|
| 474 |
+
('environment', 'TEXT'),
|
| 475 |
+
('user_role', 'TEXT'),
|
| 476 |
+
('requires_human', 'BOOLEAN'),
|
| 477 |
+
('rollback_feasible', 'BOOLEAN'),
|
| 478 |
+
('hour_of_day', 'INTEGER'),
|
| 479 |
+
('action_category', 'TEXT')
|
| 480 |
+
])
|
| 481 |
conn.execute('''
|
| 482 |
CREATE TABLE IF NOT EXISTS signals (
|
| 483 |
id TEXT PRIMARY KEY,
|
|
|
|
| 511 |
conn.close()
|
| 512 |
|
| 513 |
def store_incident(self, action: str, risk_score: float, risk_level: RiskLevel,
|
| 514 |
+
confidence: float, allowed: bool, gates: List[Dict],
|
| 515 |
+
environment: str, user_role: str, requires_human: bool,
|
| 516 |
+
rollback_feasible: bool, hour_of_day: int, action_category: str):
|
| 517 |
action_hash = hashlib.sha256(action.encode()).hexdigest()[:50]
|
|
|
|
| 518 |
incident_text = self._build_incident_text(action)
|
| 519 |
embedding = json.dumps(self._simple_embedding(incident_text))
|
| 520 |
try:
|
| 521 |
with self._get_db() as conn:
|
| 522 |
conn.execute('''
|
| 523 |
INSERT INTO incidents
|
| 524 |
+
(id, action, action_hash, risk_score, risk_level, confidence, allowed, gates, timestamp, embedding,
|
| 525 |
+
environment, user_role, requires_human, rollback_feasible, hour_of_day, action_category)
|
| 526 |
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
| 527 |
''', (
|
| 528 |
str(uuid.uuid4()),
|
| 529 |
action[:500],
|
|
|
|
| 534 |
1 if allowed else 0,
|
| 535 |
json.dumps(gates),
|
| 536 |
datetime.utcnow().isoformat(),
|
| 537 |
+
embedding,
|
| 538 |
+
environment,
|
| 539 |
+
user_role,
|
| 540 |
+
1 if requires_human else 0,
|
| 541 |
+
1 if rollback_feasible else 0,
|
| 542 |
+
hour_of_day,
|
| 543 |
+
action_category
|
| 544 |
))
|
| 545 |
conn.commit()
|
| 546 |
except sqlite3.Error as e:
|
|
|
|
| 733 |
confidence_score: float
|
| 734 |
evaluation_details: Dict[str, Any]
|
| 735 |
|
| 736 |
+
# ============== GLOBAL HMC MODEL DATA ==============
|
| 737 |
+
hmc_model_data = None
|
| 738 |
+
|
| 739 |
+
def load_hmc_model():
    """
    Load a previously trained HMC model from ``<data_dir>/hmc_model.json``.

    On success the parsed JSON object is stored in the module-level
    ``hmc_model_data``. If the file is missing, the global is left untouched
    and the default risk engine stays in use. If the file is unreadable,
    is not valid JSON, or does not contain a JSON object, the global is
    reset to ``None``.
    """
    global hmc_model_data
    model_path = os.path.join(settings.data_dir, "hmc_model.json")

    # Open directly instead of checking existence first: this avoids a race
    # between the check and the open. UTF-8 is explicit because coefficient
    # names in the model are non-ASCII (e.g. 'α_cat', 'β_env').
    try:
        with open(model_path, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
    except FileNotFoundError:
        logger.info("No HMC model found; using default risk engine")
        return
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.error(f"Failed to load HMC model: {e}")
        hmc_model_data = None
        return

    # Consumers call .get() on the model data, so anything but an object is unusable.
    if not isinstance(loaded, dict):
        logger.error(f"Failed to load HMC model: expected a JSON object, got {type(loaded).__name__}")
        hmc_model_data = None
        return

    hmc_model_data = loaded
    logger.info("HMC model loaded successfully")
|
| 752 |
+
|
| 753 |
# ============== FASTAPI APP ==============
|
| 754 |
app = FastAPI(
|
| 755 |
title="ARF OSS Real Engine (API Only)",
|
|
|
|
| 773 |
risk_engine = BayesianRiskEngine()
|
| 774 |
policy_engine = PolicyEngine()
|
| 775 |
memory = RAGMemory()
|
| 776 |
+
load_hmc_model() # Load HMC model after memory init
|
| 777 |
|
| 778 |
# ============== INFRASTRUCTURE SIMULATOR INSTANCE ==============
|
| 779 |
# Corrected: RegionAllowedPolicy expects 'allowed_regions', not 'regions'
|
|
|
|
| 829 |
"environment": "production",
|
| 830 |
"user_role": request.user_role,
|
| 831 |
"backup_available": request.rollbackFeasible,
|
| 832 |
+
"requires_human": request.requiresHuman,
|
| 833 |
+
"confidence": request.confidenceScore # added for enhanced_risk
|
| 834 |
}
|
| 835 |
+
# Use HMC-enhanced risk if available
|
| 836 |
+
if hmc_model_data:
|
| 837 |
+
risk_score_val = risk_engine.enhanced_risk(request.proposedAction, context, hmc_model_data)
|
| 838 |
+
# Convert to a risk dict compatible with policy engine (needs level and interval)
|
| 839 |
+
# For simplicity, reuse the simple engine's level mapping based on enhanced score
|
| 840 |
+
risk = risk_engine.calculate_posterior(request.proposedAction, context)
|
| 841 |
+
risk["score"] = risk_score_val
|
| 842 |
+
if risk_score_val > 0.8:
|
| 843 |
+
risk["level"] = RiskLevel.CRITICAL
|
| 844 |
+
elif risk_score_val > 0.6:
|
| 845 |
+
risk["level"] = RiskLevel.HIGH
|
| 846 |
+
elif risk_score_val > 0.4:
|
| 847 |
+
risk["level"] = RiskLevel.MEDIUM
|
| 848 |
+
else:
|
| 849 |
+
risk["level"] = RiskLevel.LOW
|
| 850 |
+
# Recalculate credible interval? We'll keep the simple one for now.
|
| 851 |
+
else:
|
| 852 |
+
risk = risk_engine.calculate_posterior(request.proposedAction, context)
|
| 853 |
+
|
| 854 |
policy = policy_engine.evaluate(
|
| 855 |
action=request.proposedAction,
|
| 856 |
risk=risk,
|
|
|
|
| 858 |
)
|
| 859 |
similar = memory.find_similar(request.proposedAction, limit=3)
|
| 860 |
|
| 861 |
+
# Capture additional fields for logging
|
| 862 |
+
environment = context["environment"]
|
| 863 |
+
user_role = request.user_role
|
| 864 |
+
requires_human = request.requiresHuman
|
| 865 |
+
rollback_feasible = request.rollbackFeasible
|
| 866 |
+
hour_of_day = datetime.now().hour
|
| 867 |
+
action_category = risk_engine.classify_action(request.proposedAction)
|
| 868 |
+
|
| 869 |
if not policy["allowed"] and risk["score"] > 0.7:
|
| 870 |
memory.track_enterprise_signal(
|
| 871 |
signal_type=LeadSignal.HIGH_RISK_BLOCKED,
|
|
|
|
| 890 |
risk_level=risk["level"],
|
| 891 |
confidence=request.confidenceScore,
|
| 892 |
allowed=policy["allowed"],
|
| 893 |
+
gates=policy["gates"],
|
| 894 |
+
environment=environment,
|
| 895 |
+
user_role=user_role,
|
| 896 |
+
requires_human=requires_human,
|
| 897 |
+
rollback_feasible=rollback_feasible,
|
| 898 |
+
hour_of_day=hour_of_day,
|
| 899 |
+
action_category=action_category
|
| 900 |
)
|
| 901 |
gates = []
|
| 902 |
for g in policy["gates"]:
|
|
|
|
| 1022 |
logger.error(f"Infrastructure evaluation failed: {e}", exc_info=True)
|
| 1023 |
raise HTTPException(500, detail=str(e))
|
| 1024 |
|
| 1025 |
+
# ============== NEW HMC TRAINING ENDPOINT ==============
|
| 1026 |
+
@app.post("/api/v1/admin/train_hmc", dependencies=[Depends(verify_api_key)])
async def train_hmc():
    """
    Trigger HMC training on historical incident data.

    Training runs in a worker thread so the event loop keeps serving other
    requests while it is in progress. On success the module-level
    ``hmc_model_data`` is replaced with the new model.

    Returns:
        A dict with ``status``, ``message`` and the model's ``coefficients``.

    Raises:
        HTTPException: 500 if training fails or produces no model data.
    """
    global hmc_model_data
    # Local imports keep this block self-contained; both are standard library.
    import asyncio
    import functools

    try:
        db_path = f"{settings.data_dir}/memory.db"
        # train_hmc_model is a synchronous call; running it directly inside
        # this coroutine would block every other request until it returned.
        loop = asyncio.get_running_loop()
        model_data = await loop.run_in_executor(
            None,
            functools.partial(train_hmc_model, db_path, output_dir=settings.data_dir),
        )
        # Validate before publishing so a failed run cannot clobber a
        # previously loaded, working model.
        if not model_data:
            raise RuntimeError("HMC training returned no model data")
        hmc_model_data = model_data
        return {"status": "success", "message": "HMC model trained and loaded", "coefficients": model_data.get("coefficients")}
    except Exception as e:
        logger.error(f"HMC training failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
| 1038 |
+
|
| 1039 |
# ============== MAIN ENTRY POINT ==============
|
| 1040 |
if __name__ == "__main__":
|
| 1041 |
import uvicorn
|