"""NPMResolverEnv -- a robust, trainable environment for RL-based npm dependency resolution.

It combines the exhaustive safety / state-management logic of the production-grade
``environment.py`` with the simple, easy-to-train reward shaping described in the PRD.
"""

from __future__ import annotations

import copy
import hashlib
import json
import random
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple


# ---------------------------------------------------------------------------
# 1. Action / Observation containers (kept lean but validated)
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class Action:
    """One agent move: set ``package_to_update`` to ``new_version`` (or ``"DELETE"``)."""

    package_to_update: str
    new_version: str

    @classmethod
    def from_json(cls, data: Dict[str, Any]) -> "Action":
        """Build an Action from an already-decoded JSON object.

        Both fields must be present, must be strings, and must be non-empty once
        surrounding whitespace is stripped; the stripped values are stored.

        Raises:
            TypeError: if ``data`` is not a dict or either field is not a string.
            ValueError: if either field is missing or blank.
        """
        if not isinstance(data, dict):
            raise TypeError("Action must be a JSON object.")
        raw_fields = (data.get("package_to_update"), data.get("new_version"))
        if any(field is None for field in raw_fields):
            raise ValueError("Action requires 'package_to_update' and 'new_version'.")
        if not all(isinstance(field, str) for field in raw_fields):
            raise TypeError("Both fields must be strings.")
        package, version = (field.strip() for field in raw_fields)
        if "" in (package, version):
            raise ValueError("Fields cannot be empty.")
        return cls(package_to_update=package, new_version=version)

    def to_dict(self) -> Dict[str, str]:
        """Return the action as a plain, JSON-serialisable dict."""
        return dict(
            package_to_update=self.package_to_update,
            new_version=self.new_version,
        )


@dataclass
class Observation:
    """What the agent sees after each reset / step."""

    current_package_json: str  # JSON text of the current dependency map
    npm_error_log: str  # conflict report, success banner, or an error message
    step_count: int  # steps taken so far in the current episode

    def to_dict(self) -> Dict[str, Any]:
        """Return the observation as a plain, JSON-serialisable dict."""
        return dict(
            current_package_json=self.current_package_json,
            npm_error_log=self.npm_error_log,
            step_count=self.step_count,
        )
# ---------------------------------------------------------------------------
# 2. The environment
# ---------------------------------------------------------------------------
class NPMResolverEnv:
    """Gym-like environment in which an agent repairs a conflicting mock package.json.

    Each episode starts from one of ``SCENARIOS``. Conflicts are derived from the
    ``requires`` entries of ``MOCK_REGISTRY``. The agent submits ``Action`` objects
    through :meth:`step` and receives ``(observation, reward, done, info)``.
    """

    # ------ internal constants ------
    _VERSION_RE = re.compile(r"^(?:\^\d+\.\d+\.\d+|\d+\.\d+\.\d+|DELETE)$")
    SUCCESS_MSG = "SUCCESS: Audited packages in 0.01s. No vulnerabilities found."
    UNTOUCHABLE = {"react", "react-dom"}  # may be re-versioned, never deleted

    # ------ reward values (match PRD and the "trainability" of the simple env) ------
    STEP_PENALTY = -1
    INVALID_ACTION_PENALTY = -5
    PROGRESS_REWARD = 10  # per resolved conflict
    REGRESSION_PENALTY = -10  # per introduced conflict (already negative)
    SUCCESS_REWARD = 50
    FATAL_PENALTY = -100
    TIMEOUT_PENALTY = -10
    NOOP_PENALTY = -2
    CORE_DELETE_PENALTY = -100  # NOTE(review): unused; core deletion is charged FATAL_PENALTY

    # ------ registry mock ------
    # package -> version -> {"requires": {dependency: required_version}}
    MOCK_REGISTRY: Dict[str, Dict[str, Dict[str, Dict[str, str]]]] = {
        "react-dom": {
            "^17.0.0": {"requires": {"react": "^17.0.0"}},
            "^18.0.0": {"requires": {"react": "^18.0.0"}},
        },
        "react-router-dom": {
            "6.0.0": {"requires": {"react": "^18.0.0", "react-dom": "^18.0.0"}},
        },
        "mongoose": {
            "6.0.0": {"requires": {"mongodb": "^4.0.0"}},
            "7.0.0": {"requires": {"mongodb": "^5.0.0"}},
        },
        "express-session": {
            "1.17.0": {"requires": {"express": "^4.0.0"}},
        },
        "babel-loader": {
            "8.0.0": {"requires": {"webpack": "^4.0.0", "@babel/core": "^7.0.0"}},
            "9.0.0": {"requires": {"webpack": "^5.0.0", "@babel/core": "^7.0.0"}},
        },
        "vuex": {
            "3.0.0": {"requires": {"vue": "^2.0.0"}},
            "4.0.0": {"requires": {"vue": "^3.0.0"}},
        },
    }

    # ------ curriculum scenarios ------
    SCENARIOS = {
        "level_1": [
            {"react": "^17.0.0", "react-dom": "^18.0.0"},
            {"mongodb": "^4.0.0", "mongoose": "7.0.0"},
            {"vue": "^2.0.0", "vuex": "4.0.0"},
        ],
        "level_2": [
            {"react": "^17.0.0", "react-dom": "^17.0.0", "react-router-dom": "6.0.0"},
            {"webpack": "^4.0.0", "@babel/core": "^7.0.0", "babel-loader": "9.0.0"},
        ],
        "level_3": [
            {"webpack": "^4.0.0", "@babel/core": "^6.0.0", "babel-loader": "9.0.0"},
            {"express": "^3.0.0", "express-session": "1.17.0", "mongodb": "^4.0.0", "mongoose": "7.0.0"},
        ],
    }

    def __init__(self, training_mode: bool = True, max_steps: int = 10):
        """Create the environment.

        Args:
            training_mode: True -> each reset picks a random scenario from *all*
                levels. False -> curriculum progression: scenarios are cycled
                within ``current_level`` and the level advances after
                ``level_success_threshold`` consecutive successful episodes.
            max_steps: number of steps allowed per episode before a timeout.
        """
        self.training_mode = training_mode
        self.max_steps = max_steps
        self.registry = copy.deepcopy(self.MOCK_REGISTRY)
        # Scenario dicts are key-sorted so JSON renderings are deterministic.
        self.scenarios = {
            name: [dict(sorted(s.items())) for s in scens]
            for name, scens in self.SCENARIOS.items()
        }
        self.allowed_packages = self._build_allowed_set()

        # curriculum state (only used when training_mode=False)
        self.current_level = "level_1"
        self.level_success_streak = 0
        self.level_success_threshold = 2
        self.level_indices = {name: 0 for name in self.scenarios}

        # per-episode state
        self.current_state: Dict[str, str] = {}
        self.step_count = 0
        self.last_action: Optional[Tuple[str, str]] = None
        self.episode_done = False
        self.active_level = "level_1"

        # running totals (for metrics / replay)
        self.history: List[Dict[str, Any]] = []
        self.episode_count = 0
        self.total_reward = 0.0

    # ------------------------------------------------------------------
    # 3. Core logic (ported from the robust env with simplified rewards)
    # ------------------------------------------------------------------
    def _build_allowed_set(self) -> List[str]:
        """Return a sorted list (despite the name) of every package the env knows about.

        Includes UNTOUCHABLE, every scenario package, every registry package and
        every package named in a registry ``requires`` entry.
        """
        pkgs = set(self.UNTOUCHABLE)
        for scens in self.scenarios.values():
            for s in scens:
                pkgs.update(s.keys())
        for pkg, versions in self.registry.items():
            pkgs.add(pkg)
            for meta in versions.values():
                pkgs.update(meta.get("requires", {}))
        return sorted(pkgs)

    def _extract_major(self, version: str) -> Optional[int]:
        """Return the leading integer of ``version`` (ignoring a ``^``), or None if unparsable."""
        v = version[1:] if version.startswith("^") else version
        try:
            return int(v.split(".", 1)[0])
        except (ValueError, IndexError):
            return None

    def is_compatible(self, required: str, actual: str) -> bool:
        """Return True if ``actual`` satisfies ``required``.

        Versions match when the strings are identical or their major numbers are
        equal. The caret/exact distinction is not enforced beyond the major version.
        """
        if required == actual:
            return True
        rm = self._extract_major(required)
        am = self._extract_major(actual)
        return rm is not None and am is not None and rm == am

    def _generate_error_log(self, deps: Dict[str, str]) -> str:
        """Return SUCCESS_MSG, or one line per missing/incompatible requirement in ``deps``.

        Only package@version pairs present in the registry contribute
        requirements. Output is sorted for determinism.
        """
        if not isinstance(deps, dict):
            return "FATAL: Invalid state."
        conflicts = []
        for pkg, ver in sorted(deps.items()):
            entry = self.registry.get(pkg, {}).get(ver)
            if not entry:
                continue
            for req, req_ver in sorted(entry.get("requires", {}).items()):
                cur = deps.get(req)
                if cur is None:
                    conflicts.append(
                        f"Missing peer dependency: {pkg}@{ver} requires {req}@{req_ver}"
                    )
                elif not self.is_compatible(req_ver, cur):
                    conflicts.append(
                        f"Conflict: {pkg}@{ver} requires {req}@{req_ver}, but found {req}@{cur}"
                    )
        return "\n".join(conflicts) if conflicts else self.SUCCESS_MSG

    def _error_count(self, log: str) -> int:
        """Return the number of lines in a conflict log (0 for SUCCESS_MSG)."""
        return 0 if log == self.SUCCESS_MSG else log.count("\n") + 1

    def _validate_action(self, action: Action) -> Optional[str]:
        """Return an error message describing why ``action`` is invalid, or None if valid."""
        if not isinstance(action, Action):
            return "Error: Action must be an Action instance."
        if not isinstance(action.package_to_update, str) or not isinstance(action.new_version, str):
            return "Error: Action fields must be strings."
        pkg, ver = action.package_to_update.strip(), action.new_version.strip()
        if not pkg or not ver:
            return "Error: Action fields cannot be empty."
        # Only packages already in the manifest may be changed (no additions).
        if pkg not in self.current_state:
            return f"Error: Package '{pkg}' not in current state."
        if not self._VERSION_RE.match(ver):
            return f"Error: Invalid version format '{ver}'."
        # Registry packages are restricted to published versions; others accept any valid format.
        if ver != "DELETE" and pkg in self.registry and ver not in self.registry[pkg]:
            return f"Error: Unknown version '{ver}' for '{pkg}'."
        return None

    def _apply_transition(self, state: Dict[str, str], action: Action) -> Tuple[Dict[str, str], Optional[str]]:
        """Apply ``action`` to a copy of ``state``; ``state`` itself is never mutated.

        Returns:
            ``(new_state, None)`` on success, or ``(partial_state, "FATAL: ...")``
            when the move deletes an UNTOUCHABLE package or leaves < 2 packages.
        """
        if not isinstance(state, dict):
            return state, "FATAL: Invalid state."
        s = dict(sorted(state.items()))
        pkg, ver = action.package_to_update.strip(), action.new_version.strip()
        if ver == "DELETE":
            if pkg in self.UNTOUCHABLE:
                return s, "FATAL: Cannot remove a core framework package."
            del s[pkg]
        else:
            s[pkg] = ver
        # invariants
        if len(s) < 2:
            return s, "FATAL: State must contain at least 2 packages."
        return dict(sorted(s.items())), None

    def _progress_reward(self, old_errors: int, new_errors: int) -> int:
        """Return the shaping reward for going from ``old_errors`` to ``new_errors`` conflicts.

        Each resolved conflict earns PROGRESS_REWARD. Each introduced conflict
        costs REGRESSION_PENALTY. That constant is already negative, so it is
        multiplied by the *positive* number of new conflicts. Multiplying two
        negatives would turn the penalty into a reward.
        """
        delta = old_errors - new_errors
        if delta > 0:
            return delta * self.PROGRESS_REWARD
        if delta < 0:
            return -delta * self.REGRESSION_PENALTY
        return 0

    def _record_curriculum_result(self, success: bool) -> None:
        """Update the success streak after an episode ends (curriculum mode only).

        A failure (fatal move or timeout) breaks the streak. Reaching
        ``level_success_threshold`` consecutive successes advances to the next
        level, in ``self.scenarios`` order, and restarts the streak.
        """
        if self.training_mode:
            return
        if not success:
            self.level_success_streak = 0
            return
        self.level_success_streak += 1
        if self.level_success_streak < self.level_success_threshold:
            return
        levels = list(self.scenarios)
        idx = levels.index(self.current_level) if self.current_level in levels else 0
        if idx < len(levels) - 1:
            self.current_level = levels[idx + 1]
            self.level_indices[self.current_level] = 0
        self.level_success_streak = 0

    def _state_json(self) -> str:
        """Render the current dependency map as deterministic, indented JSON."""
        return json.dumps(self.current_state, indent=2, sort_keys=True)

    def _observe(self, log: str) -> Observation:
        """Build an Observation of the current state paired with ``log``."""
        return Observation(self._state_json(), log, self.step_count)

    # ------------------------------------------------------------------
    # 4. Gym-like interface (reset, step)
    # ------------------------------------------------------------------
    def reset(self) -> Observation:
        """Start a fresh episode (random scenario if training_mode, else curriculum)."""
        self.step_count = 0
        self.last_action = None
        self.episode_done = False
        self.episode_count += 1

        if self.training_mode:
            # pick a random scenario from any level
            all_scenarios = [sc for scens in self.scenarios.values() for sc in scens]
            self.current_state = dict(random.choice(all_scenarios))
            self.active_level = "mixed"
        else:
            # cycle deterministically through the scenarios of the current level
            level = self.current_level
            scenarios = self.scenarios[level]
            idx = self.level_indices[level] % len(scenarios)
            self.level_indices[level] = idx + 1
            self.current_state = dict(scenarios[idx])
            self.active_level = level

        return self._observe(self._generate_error_log(self.current_state))

    def step(self, action: Action) -> Tuple[Observation, int, bool, Dict[str, Any]]:
        """Apply one action and return ``(observation, reward, done, info)``.

        ``info["status"]`` is one of "in_progress", "invalid_action", "failed",
        "timeout", "success" or "done" (episode was already over).
        """
        if self.episode_done:
            # episode already terminated: return the last observation unchanged
            obs = self._observe(self._generate_error_log(self.current_state))
            return obs, 0, True, {"status": "done", "reason": "episode_already_terminated"}

        self.step_count += 1
        info: Dict[str, Any] = {"status": "in_progress"}

        # 4a. timeout check
        # NOTE(review): the limit is checked before the action is applied, so the
        # action submitted on call max_steps + 1 is discarded. Confirm this is intended.
        if self.step_count > self.max_steps:
            log = f"TIMEOUT: Max steps ({self.max_steps}) reached.\n{self._generate_error_log(self.current_state)}"
            reward = self.TIMEOUT_PENALTY
            info["status"] = "timeout"
            self.episode_done = True
            self._record_curriculum_result(success=False)
            self.total_reward += reward
            return self._observe(log), reward, True, info

        # 4b. validate action: invalid actions are penalised but do not end the episode
        error_msg = self._validate_action(action)
        if error_msg:
            reward = self.INVALID_ACTION_PENALTY
            info["status"] = "invalid_action"
            info["error"] = error_msg
            self.total_reward += reward
            return self._observe(error_msg), reward, False, info

        pkg, ver = action.package_to_update.strip(), action.new_version.strip()
        self.last_action = (pkg, ver)
        old_errors = self._error_count(self._generate_error_log(self.current_state))

        # 4c. apply transition: fatal errors (e.g. core delete) end the episode
        new_state, err = self._apply_transition(self.current_state, action)
        if err:
            reward = self.FATAL_PENALTY
            info["status"] = "failed"
            info["reason"] = err
            self.episode_done = True
            self._record_curriculum_result(success=False)
            self.total_reward += reward
            return self._observe(err), reward, True, info

        reward = self.STEP_PENALTY

        # 4d. no-op check: the action left the manifest unchanged (this also covers
        # repeating the previous action verbatim).
        if new_state == self.current_state:
            reward += self.NOOP_PENALTY

        # commit state change
        self.current_state = dict(sorted(new_state.items()))
        new_log = self._generate_error_log(self.current_state)
        new_errors = self._error_count(new_log)

        # 4e. progress / regression shaping
        reward += self._progress_reward(old_errors, new_errors)

        # 4f. ultimate success
        done = False
        if new_log == self.SUCCESS_MSG:
            reward += self.SUCCESS_REWARD
            done = True
            info["status"] = "success"
            self.episode_done = True
            self._record_curriculum_result(success=True)

        info["errors_remaining"] = new_errors
        self.total_reward += reward
        return self._observe(new_log), reward, done, info

    # ------------------------------------------------------------------
    # 5. Auxiliary helpers (metrics, serialisation, action mask)
    # ------------------------------------------------------------------
    def get_metrics(self) -> Dict[str, float]:
        """Return episode count, cumulative reward and the active level label."""
        return {
            "episodes": self.episode_count,
            "total_reward": self.total_reward,
            "active_level": self.active_level,
        }

    def _referenced_versions(self, pkg: str) -> List[str]:
        """Return sorted version strings for ``pkg`` seen in scenarios or registry requirements.

        Used for packages that are not themselves registry entries (e.g. "react").
        Values that do not match the accepted version format are skipped.
        """
        found = {s[pkg] for scens in self.scenarios.values() for s in scens if pkg in s}
        for versions in self.registry.values():
            for meta in versions.values():
                required = meta.get("requires", {}).get(pkg)
                if required is not None:
                    found.add(required)
        return sorted(v for v in found if v != "DELETE" and self._VERSION_RE.match(v))

    def get_action_mask(self) -> Dict[str, List[str]]:
        """Return ``{package: [version, ...]}`` of non-fatal moves from the current state.

        The current version of each package is excluded, since choosing it would
        be a no-op. "DELETE" is offered for non-core packages while more than two
        packages remain.
        """
        mask: Dict[str, List[str]] = {}
        base_state = dict(sorted(self.current_state.items()))
        for pkg, current_ver in base_state.items():
            if pkg in self.registry:
                versions = list(self.registry[pkg].keys())
            else:
                # e.g. "react": offer versions that scenarios use AND that the
                # registry requires, so the fix the registry demands is reachable.
                versions = self._referenced_versions(pkg)
            candidates = [v for v in versions if v != current_ver]
            if pkg not in self.UNTOUCHABLE and len(base_state) > 2:
                candidates.append("DELETE")

            # only keep versions that don't break the state invariants
            valid = [
                v for v in candidates
                if self._apply_transition(base_state, Action(package_to_update=pkg, new_version=v))[1] is None
            ]
            if valid:
                mask[pkg] = valid
        return mask

    def to_dict(self) -> Dict[str, Any]:
        """Lightweight serialisation for checkpointing. Returned containers are copies."""
        return {
            "current_state": dict(self.current_state),
            "step_count": self.step_count,
            "max_steps": self.max_steps,
            "episode_done": self.episode_done,
            "active_level": self.active_level,
            "training_mode": self.training_mode,
            "current_level": self.current_level,
            "level_success_streak": self.level_success_streak,
            "level_indices": dict(self.level_indices),
        }

    def from_dict(self, state: Dict[str, Any]) -> None:
        """Restore a checkpoint produced by :meth:`to_dict`.

        ``level_indices`` is optional, so older checkpoints still load.
        """
        self.current_state = dict(state["current_state"])
        self.step_count = int(state["step_count"])
        self.max_steps = int(state["max_steps"])
        self.episode_done = bool(state["episode_done"])
        self.active_level = str(state["active_level"])
        self.training_mode = bool(state["training_mode"])
        self.current_level = str(state["current_level"])
        self.level_success_streak = int(state["level_success_streak"])
        saved_indices = state.get("level_indices")
        if saved_indices is not None:
            for name in self.level_indices:
                if name in saved_indices:
                    self.level_indices[name] = int(saved_indices[name])
        self.last_action = None  # not checkpointed
# ---------------------------------------------------------------------------
# 6. Quick local test harness (run `python environment.py` to verify)
# ---------------------------------------------------------------------------
def _run_self_test() -> None:
    """Smoke-test validation, fatal handling and a solved episode.

    Curriculum mode (``training_mode=False``) makes ``reset`` return the first
    level_1 scenario (react / react-dom) every time. The assertions below assume
    that scenario, whereas random training mode could land on vue or mongoose.
    """
    print("=== Robust NPMResolverEnv self‑test ===")

    env = NPMResolverEnv(training_mode=False, max_steps=10)
    first_obs = env.reset()
    print("Initial state:", first_obs.current_package_json[:80], "...")

    # Test 1: a malformed version string is penalised but does not end the episode.
    _, reward, done, _ = env.step(Action("react-dom", "garbage"))
    assert reward == env.INVALID_ACTION_PENALTY and not done
    print("[PASS] Invalid version penalised correctly.")

    # Test 2: a package absent from the manifest is rejected.
    _, reward, _, _ = env.step(Action("made-up", "1.0.0"))
    assert reward == env.INVALID_ACTION_PENALTY
    print("[PASS] Unknown package penalised correctly.")

    # Test 3: deleting a core framework package is fatal.
    fatal_env = NPMResolverEnv(training_mode=False)
    fatal_env.reset()
    _, reward, done, _ = fatal_env.step(Action("react", "DELETE"))
    assert reward == env.FATAL_PENALTY and done
    print("[PASS] Core deletion triggers fatal penalty.")

    # Test 4: a successful repair is rewarded (in one or two steps).
    solver = NPMResolverEnv(training_mode=False)
    solver.reset()
    _, first_reward, first_done, _ = solver.step(Action("react", "^18.0.0"))
    solved_now = solver._generate_error_log(solver.current_state) == solver.SUCCESS_MSG
    if solved_now:
        assert first_done and first_reward > 0
        print("[PASS] Fixed in 1 step (expected for level_1 simple mismatch).")
    else:
        assert not first_done
        _, second_reward, second_done, _ = solver.step(Action("react-dom", "^18.0.0"))
        assert second_done and second_reward > 30  # success bonus + progress
        print("[PASS] Two‑step fix rewarded correctly.")

    print("\nAll tests passed – environment is bulletproof and trainable.")


if __name__ == "__main__":
    _run_self_test()