"""
Graph-grounded DRL environment.

Replaces the 10 hardcoded archetypes in EAEnvironment with capabilities
loaded live from Neo4j for a specific domain. Maintains the same
STATE_DIM=20 / ACTION_DIM=10 interface so REINFORCETrainer works unchanged.
"""

import numpy as np
from typing import NamedTuple

from backend.graph.cypher_queries import (
    GET_CAPABILITIES_FOR_TRAINING,
    GET_DOMAIN_RELATIONSHIP_FLAGS,
)


class EAScenario(NamedTuple):
    """One sampled episode configuration, as built by GraphEAEnvironment.reset()."""

    cap_business_values: np.ndarray
    cap_effort_scores: np.ndarray
    cap_risk_scores: np.ndarray
    # dependency_matrix[a, b] == 1.0 means capability slot a depends on slot b,
    # i.e. b should be scheduled before a (this is the convention step() checks).
    dependency_matrix: np.ndarray
    budget_capacity: float
    timeline_score: float
    risk_tolerance: float


# Number of capability slots; must agree with GraphEAEnvironment.ACTION_DIM.
_NUM_SLOTS = 10

# Map a capability's "complexity" label to a base business value (simpler -> higher).
_COMPLEXITY_TO_BV = {"low": 0.95, "medium": 0.80, "high": 0.65, "very_high": 0.50}
# NOTE(review): not referenced in this module; effort is derived from
# duration_weeks instead. Kept in case other modules import it.
_COMPLEXITY_TO_EFFORT = {"low": 0.30, "medium": 0.55, "high": 0.75, "very_high": 0.90}

# Keys read from the first GET_DOMAIN_RELATIONSHIP_FLAGS row, in state-vector order.
_DOMAIN_FLAG_KEYS = (
    "is_sector_hub",
    "is_enabled",
    "is_orchestrator",
    "is_governed",
    "is_sector_child",
    "enables_others",
    "has_trend",
)


def _build_bv_effort_risk(caps: list[dict]) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Derive base business-value, effort and risk arrays from capability records.

    Only the first ``_NUM_SLOTS`` records are used. Missing slots are padded with
    placeholder values (bv=0.70, effort=0.50, risk=0.30).

    Args:
        caps: capability rows; the keys read are "complexity", "duration_weeks"
            and "risk_factors" (all optional).

    Returns:
        Three float32 arrays of length ``_NUM_SLOTS``: business value, effort, risk.
    """
    bv, effort, risk = [], [], []
    for c in caps[:_NUM_SLOTS]:
        # str() guards against non-string complexity values coming from the graph.
        cx = str(c.get("complexity") or "medium").lower()
        bv.append(_COMPLEXITY_TO_BV.get(cx, 0.80))
        # Effort: duration normalised against 36 weeks, capped at 1.0.
        dur = c.get("duration_weeks") or 16
        effort.append(min(float(dur) / 36.0, 1.0))
        # Risk: number of risk factors normalised against 5, capped at 0.90.
        rf = c.get("risk_factors") or []
        risk.append(min(len(rf) / 5.0, 0.90))

    # Pad to exactly _NUM_SLOTS slots.
    while len(bv) < _NUM_SLOTS:
        bv.append(0.70)
        effort.append(0.50)
        risk.append(0.30)

    return (
        np.array(bv, dtype=np.float32),
        np.array(effort, dtype=np.float32),
        np.array(risk, dtype=np.float32),
    )


class GraphEAEnvironment:
    """DRL environment grounded in real Neo4j domain capabilities.

    State (STATE_DIM=20): 10 noisy capability business values, budget capacity,
    timeline score, risk tolerance, and 7 domain relationship flags.
    Action (ACTION_DIM=10): capability slot indices in implementation order
    (sample_action() returns a random permutation). Every step() ends the
    episode.
    """

    STATE_DIM = 20
    ACTION_DIM = 10

    def __init__(self, neo4j_client, domain_name: str,
                 noise_scale: float = 0.08, seed: int | None = None):
        """Load the domain's capabilities and relationship flags from Neo4j.

        Args:
            neo4j_client: object exposing ``run_query(cypher, **params)`` that
                returns a list of dict-like rows (or a falsy value when empty).
            domain_name: domain whose capabilities ground the environment.
            noise_scale: half-width of the uniform noise applied in reset().
            seed: optional seed for the environment's numpy Generator.
        """
        self._client = neo4j_client
        self._domain_name = domain_name
        self._noise = noise_scale
        self._rng = np.random.default_rng(seed)

        caps = self._client.run_query(GET_CAPABILITIES_FOR_TRAINING, domain_name=domain_name)
        self._caps = caps or []
        # `or` (not a .get default) so that a null name also falls back.
        self._cap_names = [c.get("name") or f"Cap-{i}" for i, c in enumerate(self._caps)]
        self._base_bv, self._base_ef, self._base_ri = _build_bv_effort_risk(self._caps)

        # Domain relationship flags (7 dims). bool() first so null values
        # returned by the query are treated as False instead of raising.
        flags_rows = self._client.run_query(GET_DOMAIN_RELATIONSHIP_FLAGS, domain_name=domain_name)
        if flags_rows:
            f = flags_rows[0]
            self._domain_flags = np.array(
                [float(bool(f.get(key))) for key in _DOMAIN_FLAG_KEYS],
                dtype=np.float32,
            )
        else:
            self._domain_flags = np.zeros(len(_DOMAIN_FLAG_KEYS), dtype=np.float32)

        self.scenario: EAScenario | None = None
        self.current_step = 0

    # ------------------------------------------------------------------
    def reset(self) -> np.ndarray:
        """Sample a new scenario around the domain's base values.

        Returns:
            The initial state vector (float32, length STATE_DIM).
        """
        n = _NUM_SLOTS
        noise = self._rng.uniform(-self._noise, self._noise, n)
        bv = np.clip(self._base_bv + noise, 0.1, 1.0)
        ef = np.clip(self._base_ef + self._rng.uniform(-self._noise, self._noise, n), 0.1, 1.0)
        ri = np.clip(
            self._base_ri + self._rng.uniform(-self._noise / 2, self._noise / 2, n), 0.05, 0.9
        )

        # Light dependency chain: later capabilities often depend on earlier
        # ones. dep_matrix[a, b] == 1.0 means a depends on b, so "slot i + 1
        # depends on slot i" is stored at [i + 1, i] (matching step()).
        dep_matrix = np.zeros((n, n), dtype=np.float32)
        for i in range(min(len(self._caps) - 1, n - 1)):
            if self._rng.random() > 0.7:
                dep_matrix[i + 1, i] = 1.0

        budget_capacity = float(self._rng.choice([0.4, 0.6, 0.8, 1.0]))
        timeline_score = float(self._rng.choice([6, 12, 18, 24, 36])) / 36.0
        risk_tolerance = float(self._rng.choice([0.33, 0.67, 1.0]))

        self.scenario = EAScenario(bv, ef, ri, dep_matrix,
                                   budget_capacity, timeline_score, risk_tolerance)
        self.current_step = 0
        return self._state_vector()

    def _require_scenario(self) -> EAScenario:
        """Return the current scenario, raising if reset() has not been called."""
        if self.scenario is None:
            raise RuntimeError("GraphEAEnvironment.reset() must be called before step()")
        return self.scenario

    def _state_vector(self) -> np.ndarray:
        """Concatenate business values, scenario scalars and domain flags (float32)."""
        s = self._require_scenario()
        return np.concatenate([
            s.cap_business_values,
            [s.budget_capacity],
            [s.timeline_score],
            [s.risk_tolerance],
            self._domain_flags,
        ]).astype(np.float32)

    def step(self, action_indices: np.ndarray) -> tuple[np.ndarray, float, bool]:
        """Score an implementation order and end the episode.

        Args:
            action_indices: capability slot indices in implementation order
                (normally a permutation of range(ACTION_DIM)).

        Returns:
            ``(state, reward, done)``. The reward is clipped to [-1.0, 2.0] and
            ``done`` is always True.

        Raises:
            RuntimeError: if called before reset().
            ValueError: if ``action_indices`` is empty.
        """
        s = self._require_scenario()
        n_actions = len(action_indices)
        if n_actions == 0:
            raise ValueError("action_indices must not be empty")

        # Business value weighted by a linear rank decay (earlier positions count more).
        base_reward = sum(
            s.cap_business_values[idx] * (1.0 - rank / n_actions)
            for rank, idx in enumerate(action_indices)
        ) / n_actions

        # Count pairs where capability dj is scheduled before di although dj
        # depends on di (dependency_matrix[dj, di] == 1).
        dep_violations = sum(
            1
            for i, di in enumerate(action_indices)
            for j, dj in enumerate(action_indices)
            if j < i and s.dependency_matrix[dj, di] == 1.0
        )
        dep_penalty = dep_violations * 0.15

        # Penalise each position at which cumulative (scaled) effort exceeds the budget.
        cum_effort, budget_penalty = 0.0, 0.0
        for idx in action_indices:
            cum_effort += s.cap_effort_scores[idx] / 10.0
            if cum_effort > s.budget_capacity:
                budget_penalty += 0.05

        # Penalise capabilities above the risk tolerance in the first three positions.
        risk_penalty = sum(
            s.cap_risk_scores[idx] * 0.2
            for idx in action_indices[:3]
            if s.cap_risk_scores[idx] > s.risk_tolerance
        )

        raw_reward = base_reward - dep_penalty - budget_penalty - risk_penalty
        reward = float(max(-1.0, min(2.0, raw_reward)))
        self.current_step += 1
        return self._state_vector(), reward, True

    def sample_action(self) -> np.ndarray:
        """Return a uniformly random permutation of the ACTION_DIM slots (int64)."""
        return self._rng.permutation(self.ACTION_DIM).astype(np.int64)

    def get_domain_name(self) -> str:
        """Return the domain this environment was built for."""
        return self._domain_name

    def get_capability_names(self) -> list[str]:
        """Return the names of all loaded capabilities.

        NOTE: this list is not padded or truncated. Only the first ACTION_DIM
        capabilities populate action slots, and padded slots have no name.
        """
        return self._cap_names