petter2025 commited on
Commit
695eb99
·
verified ·
1 Parent(s): f836001

Update hf_demo.py

Browse files
Files changed (1) hide show
  1. hf_demo.py +143 -11
hf_demo.py CHANGED
@@ -36,6 +36,9 @@ from infrastructure import (
36
  RecommendedAction,
37
  )
38
 
 
 
 
39
  # ============== CONFIGURATION (Pydantic V2) ==============
40
  class Settings(BaseSettings):
41
  """Application settings loaded from environment variables."""
@@ -257,6 +260,44 @@ class BayesianRiskEngine:
257
  except sqlite3.Error as e:
258
  logger.error(f"Failed to record outcome: {e}")
259
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  class PolicyEngine:
261
  """Deterministic OSS policies – advisory only."""
262
  def __init__(self):
@@ -399,6 +440,18 @@ class RAGMemory:
399
  self.embedding_cache[text] = embedding
400
  return embedding
401
 
 
 
 
 
 
 
 
 
 
 
 
 
402
  def _init_db(self):
403
  try:
404
  with self._get_db() as conn:
@@ -416,6 +469,15 @@ class RAGMemory:
416
  embedding TEXT
417
  )
418
  ''')
 
 
 
 
 
 
 
 
 
419
  conn.execute('''
420
  CREATE TABLE IF NOT EXISTS signals (
421
  id TEXT PRIMARY KEY,
@@ -449,17 +511,19 @@ class RAGMemory:
449
  conn.close()
450
 
451
  def store_incident(self, action: str, risk_score: float, risk_level: RiskLevel,
452
- confidence: float, allowed: bool, gates: List[Dict]):
 
 
453
  action_hash = hashlib.sha256(action.encode()).hexdigest()[:50]
454
- # Build a descriptive text and generate embedding
455
  incident_text = self._build_incident_text(action)
456
  embedding = json.dumps(self._simple_embedding(incident_text))
457
  try:
458
  with self._get_db() as conn:
459
  conn.execute('''
460
  INSERT INTO incidents
461
- (id, action, action_hash, risk_score, risk_level, confidence, allowed, gates, timestamp, embedding)
462
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
 
463
  ''', (
464
  str(uuid.uuid4()),
465
  action[:500],
@@ -470,7 +534,13 @@ class RAGMemory:
470
  1 if allowed else 0,
471
  json.dumps(gates),
472
  datetime.utcnow().isoformat(),
473
- embedding
 
 
 
 
 
 
474
  ))
475
  conn.commit()
476
  except sqlite3.Error as e:
@@ -663,6 +733,23 @@ class InfrastructureEvaluationResponse(BaseModel):
663
  confidence_score: float
664
  evaluation_details: Dict[str, Any]
665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
666
  # ============== FASTAPI APP ==============
667
  app = FastAPI(
668
  title="ARF OSS Real Engine (API Only)",
@@ -686,6 +773,7 @@ app.add_middleware(
686
  risk_engine = BayesianRiskEngine()
687
  policy_engine = PolicyEngine()
688
  memory = RAGMemory()
 
689
 
690
  # ============== INFRASTRUCTURE SIMULATOR INSTANCE ==============
691
  # Corrected: RegionAllowedPolicy expects 'allowed_regions', not 'regions'
@@ -741,12 +829,28 @@ async def evaluate_action(request: ActionRequest):
741
  "environment": "production",
742
  "user_role": request.user_role,
743
  "backup_available": request.rollbackFeasible,
744
- "requires_human": request.requiresHuman
 
745
  }
746
- risk = risk_engine.calculate_posterior(
747
- action_text=request.proposedAction,
748
- context=context
749
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
750
  policy = policy_engine.evaluate(
751
  action=request.proposedAction,
752
  risk=risk,
@@ -754,6 +858,14 @@ async def evaluate_action(request: ActionRequest):
754
  )
755
  similar = memory.find_similar(request.proposedAction, limit=3)
756
 
 
 
 
 
 
 
 
 
757
  if not policy["allowed"] and risk["score"] > 0.7:
758
  memory.track_enterprise_signal(
759
  signal_type=LeadSignal.HIGH_RISK_BLOCKED,
@@ -778,7 +890,13 @@ async def evaluate_action(request: ActionRequest):
778
  risk_level=risk["level"],
779
  confidence=request.confidenceScore,
780
  allowed=policy["allowed"],
781
- gates=policy["gates"]
 
 
 
 
 
 
782
  )
783
  gates = []
784
  for g in policy["gates"]:
@@ -904,6 +1022,20 @@ async def evaluate_infrastructure_intent(request: InfrastructureIntentRequest):
904
  logger.error(f"Infrastructure evaluation failed: {e}", exc_info=True)
905
  raise HTTPException(500, detail=str(e))
906
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
907
  # ============== MAIN ENTRY POINT ==============
908
  if __name__ == "__main__":
909
  import uvicorn
 
36
  RecommendedAction,
37
  )
38
 
39
+ # ============== HMC LEARNER IMPORT ==============
40
+ from hmc_learner import train_hmc_model # new import
41
+
42
  # ============== CONFIGURATION (Pydantic V2) ==============
43
  class Settings(BaseSettings):
44
  """Application settings loaded from environment variables."""
 
260
  except sqlite3.Error as e:
261
  logger.error(f"Failed to record outcome: {e}")
262
 
263
+ # ---------- NEW: Enhanced risk using HMC coefficients ----------
264
+ def enhanced_risk(self, action_text: str, context: Dict, hmc_coeffs: Optional[Dict] = None) -> float:
265
+ """
266
+ Compute a risk score using HMC coefficients if available.
267
+ Falls back to simple posterior score if no coefficients.
268
+ """
269
+ if hmc_coeffs is None:
270
+ return self.calculate_posterior(action_text, context)["score"]
271
+
272
+ # Build feature vector (same as in hmc_learner preprocessing)
273
+ action_cat = self.classify_action(action_text)
274
+ # Map category to code using saved mapping (if present)
275
+ cat_mapping = hmc_coeffs.get("action_cat_mapping", {})
276
+ # Invert mapping (category -> code)
277
+ cat_to_code = {v: k for k, v in cat_mapping.items()}
278
+ cat_code = cat_to_code.get(action_cat, 0) # default to 0 if not found
279
+
280
+ env_prod = 1 if context.get('environment') == 'production' else 0
281
+ role_junior = 1 if context.get('user_role') == 'junior' else 0
282
+ hour = datetime.now().hour
283
+ # Use the simple posterior risk as a feature (normalized)
284
+ simple_risk = self.calculate_posterior(action_text, context)["score"]
285
+ confidence = context.get('confidence', 0.85)
286
+
287
+ # Linear predictor from HMC coefficients
288
+ logit = (
289
+ hmc_coeffs.get('α_cat', {}).get('mean', [0])[cat_code] +
290
+ hmc_coeffs.get('β_env', {}).get('mean', 0) * env_prod +
291
+ hmc_coeffs.get('β_role', {}).get('mean', 0) * role_junior +
292
+ hmc_coeffs.get('β_risk', {}).get('mean', 0) * (simple_risk - 0.5) +
293
+ hmc_coeffs.get('β_hour', {}).get('mean', 0) * ((hour - 12) / 12) +
294
+ hmc_coeffs.get('β_conf', {}).get('mean', 0) * (confidence - 0.5)
295
+ )
296
+ # Convert to probability
297
+ prob = 1 / (1 + np.exp(-logit))
298
+ return prob
299
+
300
+
301
  class PolicyEngine:
302
  """Deterministic OSS policies – advisory only."""
303
  def __init__(self):
 
440
  self.embedding_cache[text] = embedding
441
  return embedding
442
 
443
+ def _ensure_columns(self, conn, columns):
444
+ """Add columns to incidents table if they do not exist."""
445
+ cursor = conn.execute("PRAGMA table_info(incidents)")
446
+ existing = [row[1] for row in cursor.fetchall()]
447
+ for col_name, col_type in columns:
448
+ if col_name not in existing:
449
+ try:
450
+ conn.execute(f"ALTER TABLE incidents ADD COLUMN {col_name} {col_type}")
451
+ logger.info(f"Added column {col_name} to incidents table")
452
+ except sqlite3.Error as e:
453
+ logger.error(f"Failed to add column {col_name}: {e}")
454
+
455
  def _init_db(self):
456
  try:
457
  with self._get_db() as conn:
 
469
  embedding TEXT
470
  )
471
  ''')
472
+ # Add new columns if they don't exist
473
+ self._ensure_columns(conn, [
474
+ ('environment', 'TEXT'),
475
+ ('user_role', 'TEXT'),
476
+ ('requires_human', 'BOOLEAN'),
477
+ ('rollback_feasible', 'BOOLEAN'),
478
+ ('hour_of_day', 'INTEGER'),
479
+ ('action_category', 'TEXT')
480
+ ])
481
  conn.execute('''
482
  CREATE TABLE IF NOT EXISTS signals (
483
  id TEXT PRIMARY KEY,
 
511
  conn.close()
512
 
513
  def store_incident(self, action: str, risk_score: float, risk_level: RiskLevel,
514
+ confidence: float, allowed: bool, gates: List[Dict],
515
+ environment: str, user_role: str, requires_human: bool,
516
+ rollback_feasible: bool, hour_of_day: int, action_category: str):
517
  action_hash = hashlib.sha256(action.encode()).hexdigest()[:50]
 
518
  incident_text = self._build_incident_text(action)
519
  embedding = json.dumps(self._simple_embedding(incident_text))
520
  try:
521
  with self._get_db() as conn:
522
  conn.execute('''
523
  INSERT INTO incidents
524
+ (id, action, action_hash, risk_score, risk_level, confidence, allowed, gates, timestamp, embedding,
525
+ environment, user_role, requires_human, rollback_feasible, hour_of_day, action_category)
526
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
527
  ''', (
528
  str(uuid.uuid4()),
529
  action[:500],
 
534
  1 if allowed else 0,
535
  json.dumps(gates),
536
  datetime.utcnow().isoformat(),
537
+ embedding,
538
+ environment,
539
+ user_role,
540
+ 1 if requires_human else 0,
541
+ 1 if rollback_feasible else 0,
542
+ hour_of_day,
543
+ action_category
544
  ))
545
  conn.commit()
546
  except sqlite3.Error as e:
 
733
  confidence_score: float
734
  evaluation_details: Dict[str, Any]
735
 
736
+ # ============== GLOBAL HMC MODEL DATA ==============
737
+ hmc_model_data = None
738
+
739
+ def load_hmc_model():
740
+ global hmc_model_data
741
+ model_path = f"{settings.data_dir}/hmc_model.json"
742
+ if os.path.exists(model_path):
743
+ try:
744
+ with open(model_path, 'r') as f:
745
+ hmc_model_data = json.load(f)
746
+ logger.info("HMC model loaded successfully")
747
+ except Exception as e:
748
+ logger.error(f"Failed to load HMC model: {e}")
749
+ hmc_model_data = None
750
+ else:
751
+ logger.info("No HMC model found; using default risk engine")
752
+
753
  # ============== FASTAPI APP ==============
754
  app = FastAPI(
755
  title="ARF OSS Real Engine (API Only)",
 
773
  risk_engine = BayesianRiskEngine()
774
  policy_engine = PolicyEngine()
775
  memory = RAGMemory()
776
+ load_hmc_model() # Load HMC model after memory init
777
 
778
  # ============== INFRASTRUCTURE SIMULATOR INSTANCE ==============
779
  # Corrected: RegionAllowedPolicy expects 'allowed_regions', not 'regions'
 
829
  "environment": "production",
830
  "user_role": request.user_role,
831
  "backup_available": request.rollbackFeasible,
832
+ "requires_human": request.requiresHuman,
833
+ "confidence": request.confidenceScore # added for enhanced_risk
834
  }
835
+ # Use HMC-enhanced risk if available
836
+ if hmc_model_data:
837
+ risk_score_val = risk_engine.enhanced_risk(request.proposedAction, context, hmc_model_data)
838
+ # Convert to a risk dict compatible with policy engine (needs level and interval)
839
+ # For simplicity, reuse the simple engine's level mapping based on enhanced score
840
+ risk = risk_engine.calculate_posterior(request.proposedAction, context)
841
+ risk["score"] = risk_score_val
842
+ if risk_score_val > 0.8:
843
+ risk["level"] = RiskLevel.CRITICAL
844
+ elif risk_score_val > 0.6:
845
+ risk["level"] = RiskLevel.HIGH
846
+ elif risk_score_val > 0.4:
847
+ risk["level"] = RiskLevel.MEDIUM
848
+ else:
849
+ risk["level"] = RiskLevel.LOW
850
+ # Recalculate credible interval? We'll keep the simple one for now.
851
+ else:
852
+ risk = risk_engine.calculate_posterior(request.proposedAction, context)
853
+
854
  policy = policy_engine.evaluate(
855
  action=request.proposedAction,
856
  risk=risk,
 
858
  )
859
  similar = memory.find_similar(request.proposedAction, limit=3)
860
 
861
+ # Capture additional fields for logging
862
+ environment = context["environment"]
863
+ user_role = request.user_role
864
+ requires_human = request.requiresHuman
865
+ rollback_feasible = request.rollbackFeasible
866
+ hour_of_day = datetime.now().hour
867
+ action_category = risk_engine.classify_action(request.proposedAction)
868
+
869
  if not policy["allowed"] and risk["score"] > 0.7:
870
  memory.track_enterprise_signal(
871
  signal_type=LeadSignal.HIGH_RISK_BLOCKED,
 
890
  risk_level=risk["level"],
891
  confidence=request.confidenceScore,
892
  allowed=policy["allowed"],
893
+ gates=policy["gates"],
894
+ environment=environment,
895
+ user_role=user_role,
896
+ requires_human=requires_human,
897
+ rollback_feasible=rollback_feasible,
898
+ hour_of_day=hour_of_day,
899
+ action_category=action_category
900
  )
901
  gates = []
902
  for g in policy["gates"]:
 
1022
  logger.error(f"Infrastructure evaluation failed: {e}", exc_info=True)
1023
  raise HTTPException(500, detail=str(e))
1024
 
1025
+ # ============== NEW HMC TRAINING ENDPOINT ==============
1026
+ @app.post("/api/v1/admin/train_hmc", dependencies=[Depends(verify_api_key)])
1027
+ async def train_hmc():
1028
+ """Trigger HMC training on historical incident data."""
1029
+ global hmc_model_data
1030
+ try:
1031
+ db_path = f"{settings.data_dir}/memory.db"
1032
+ model_data = train_hmc_model(db_path, output_dir=settings.data_dir)
1033
+ hmc_model_data = model_data
1034
+ return {"status": "success", "message": "HMC model trained and loaded", "coefficients": model_data.get("coefficients")}
1035
+ except Exception as e:
1036
+ logger.error(f"HMC training failed: {e}", exc_info=True)
1037
+ raise HTTPException(status_code=500, detail=str(e))
1038
+
1039
  # ============== MAIN ENTRY POINT ==============
1040
  if __name__ == "__main__":
1041
  import uvicorn