# npm-resolver / env.py
# Uploaded by ArpitBaliyan ("Upload 5 files", commit b20771e, verified).
"""
NPMResolverEnv — Robust, trainable environment for RL-based npm dependency resolution.
Combines the exhaustive safety / state-management logic from the production-grade
environment.py with the clean, easy-to-train reward shaping described in the PRD.
"""
from __future__ import annotations
import copy
import hashlib
import json
import random
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
# ---------------------------------------------------------------------------
# 1. Action / Observation containers (kept lean but validated)
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class Action:
    """Immutable agent action: assign ``new_version`` to ``package_to_update``.

    The environment treats the literal version ``"DELETE"`` as a request to
    remove the package instead of changing its version.
    """

    package_to_update: str
    new_version: str

    @classmethod
    def from_json(cls, data: Dict[str, Any]) -> "Action":
        """Build an action from a decoded JSON object.

        Both values are stripped of surrounding whitespace before being stored.

        Raises:
            TypeError: ``data`` is not a dict, or a field is not a string.
            ValueError: a field is missing, ``None``, or blank after stripping.
        """
        if not isinstance(data, dict):
            raise TypeError("Action must be a JSON object.")

        raw_name = data.get("package_to_update")
        raw_version = data.get("new_version")
        if raw_name is None or raw_version is None:
            raise ValueError("Action requires 'package_to_update' and 'new_version'.")
        if not (isinstance(raw_name, str) and isinstance(raw_version, str)):
            raise TypeError("Both fields must be strings.")

        name = raw_name.strip()
        version = raw_version.strip()
        if not (name and version):
            raise ValueError("Fields cannot be empty.")
        return cls(package_to_update=name, new_version=version)

    def to_dict(self) -> Dict[str, str]:
        """Return the action as a plain, JSON-serialisable mapping."""
        return dict(
            package_to_update=self.package_to_update,
            new_version=self.new_version,
        )
@dataclass
class Observation:
    """What the agent sees after ``reset`` or ``step``."""

    # JSON text of the current {package: version} state.
    current_package_json: str
    # Success message, conflict lines, or an error/timeout message.
    npm_error_log: str
    # Number of steps taken in the episode when the observation was produced.
    step_count: int

    def to_dict(self) -> Dict[str, Any]:
        """Return the observation as a plain mapping, fields in declaration order."""
        field_names = ("current_package_json", "npm_error_log", "step_count")
        return {name: getattr(self, name) for name in field_names}
# ---------------------------------------------------------------------------
# 2. The environment
# ---------------------------------------------------------------------------
class NPMResolverEnv:
    """Gym-like environment in which an agent repairs npm dependency conflicts.

    The state is a flat ``{package: version}`` mapping. Each step applies one
    ``Action`` (set a version, or ``"DELETE"`` a package). The reward reflects
    how the number of conflict lines produced by ``_generate_error_log``
    changed, plus fixed bonuses/penalties defined by the class constants below.
    """

    # ------ internal constants ------
    # Accepts "^X.Y.Z", "X.Y.Z" or the literal "DELETE".
    _VERSION_RE = re.compile(r"^(?:\^\d+\.\d+\.\d+|\d+\.\d+\.\d+|DELETE)$")
    SUCCESS_MSG = "SUCCESS: Audited packages in 0.01s. No vulnerabilities found."
    # Packages that may never be deleted.
    UNTOUCHABLE = {"react", "react-dom"}
    # Curriculum order used when training_mode is False.
    _LEVEL_ORDER = ("level_1", "level_2", "level_3")

    # ------ reward values (penalties are stored as negative numbers) ------
    STEP_PENALTY = -1
    INVALID_ACTION_PENALTY = -5
    PROGRESS_REWARD = 10  # per resolved conflict
    REGRESSION_PENALTY = -10  # per introduced conflict
    SUCCESS_REWARD = 50
    FATAL_PENALTY = -100
    TIMEOUT_PENALTY = -10
    NOOP_PENALTY = -2
    CORE_DELETE_PENALTY = -100  # not referenced by this class; FATAL_PENALTY is applied instead

    # ------ registry mock: package -> version -> {"requires": {dep: version}} ------
    MOCK_REGISTRY: Dict[str, Dict[str, Dict[str, Dict[str, str]]]] = {
        "react-dom": {
            "^17.0.0": {"requires": {"react": "^17.0.0"}},
            "^18.0.0": {"requires": {"react": "^18.0.0"}},
        },
        "react-router-dom": {
            "6.0.0": {"requires": {"react": "^18.0.0", "react-dom": "^18.0.0"}},
        },
        "mongoose": {
            "6.0.0": {"requires": {"mongodb": "^4.0.0"}},
            "7.0.0": {"requires": {"mongodb": "^5.0.0"}},
        },
        "express-session": {
            "1.17.0": {"requires": {"express": "^4.0.0"}},
        },
        "babel-loader": {
            "8.0.0": {"requires": {"webpack": "^4.0.0", "@babel/core": "^7.0.0"}},
            "9.0.0": {"requires": {"webpack": "^5.0.0", "@babel/core": "^7.0.0"}},
        },
        "vuex": {
            "3.0.0": {"requires": {"vue": "^2.0.0"}},
            "4.0.0": {"requires": {"vue": "^3.0.0"}},
        },
    }

    # ------ curriculum scenarios: level -> list of starting states ------
    SCENARIOS = {
        "level_1": [
            {"react": "^17.0.0", "react-dom": "^18.0.0"},
            {"mongodb": "^4.0.0", "mongoose": "7.0.0"},
            {"vue": "^2.0.0", "vuex": "4.0.0"},
        ],
        "level_2": [
            {"react": "^17.0.0", "react-dom": "^17.0.0", "react-router-dom": "6.0.0"},
            {"webpack": "^4.0.0", "@babel/core": "^7.0.0", "babel-loader": "9.0.0"},
        ],
        "level_3": [
            {"webpack": "^4.0.0", "@babel/core": "^6.0.0", "babel-loader": "9.0.0"},
            {"express": "^3.0.0", "express-session": "1.17.0", "mongodb": "^4.0.0", "mongoose": "7.0.0"},
        ],
    }

    def __init__(self, training_mode: bool = True, max_steps: int = 10):
        """Create an environment; call ``reset`` before the first ``step``.

        Args:
            training_mode: when True each reset picks a random scenario from
                all levels; when False scenarios are served level by level and
                the level advances after a streak of successful episodes.
            max_steps: number of steps allowed before the next ``step`` call
                ends the episode with a timeout.
        """
        self.training_mode = training_mode
        self.max_steps = max_steps
        # Private copies so instance-level edits never touch the class constants.
        self.registry = copy.deepcopy(self.MOCK_REGISTRY)
        self.scenarios = {
            name: [dict(sorted(s.items())) for s in scens]
            for name, scens in self.SCENARIOS.items()
        }
        self.allowed_packages = self._build_allowed_set()

        # curriculum state (only used when training_mode=False)
        self.current_level = "level_1"
        self.level_success_streak = 0
        self.level_success_threshold = 2
        self.level_indices = {name: 0 for name in self.scenarios}

        # per-episode state
        self.current_state: Dict[str, str] = {}
        self.step_count = 0
        self.last_action: Optional[Tuple[str, str]] = None
        self.episode_done = False
        self.active_level = "level_1"

        # cumulative bookkeeping; `history` is not populated by this class
        self.history: List[Dict[str, Any]] = []
        self.episode_count = 0
        self.total_reward = 0.0

    # ------------------------------------------------------------------
    # 3. Core logic
    # ------------------------------------------------------------------
    def _build_allowed_set(self) -> List[str]:
        """Return the sorted names of every package known to the environment.

        That is: the untouchable packages, every package in a scenario, every
        registry package and every dependency a registry entry requires.
        """
        pkgs = set(self.UNTOUCHABLE)
        for scens in self.scenarios.values():
            for s in scens:
                pkgs.update(s)
        for pkg, versions in self.registry.items():
            pkgs.add(pkg)
            for meta in versions.values():
                pkgs.update(meta.get("requires", {}))
        return sorted(pkgs)

    def _extract_major(self, version: str) -> Optional[int]:
        """Return the major number of ``version`` (leading ``^`` ignored), or None."""
        v = version[1:] if version.startswith("^") else version
        try:
            return int(v.split(".", 1)[0])
        except (ValueError, IndexError):
            return None

    def is_compatible(self, required: str, actual: str) -> bool:
        """Return True when the strings are equal or share the same major version."""
        if required == actual:
            return True
        rm = self._extract_major(required)
        am = self._extract_major(actual)
        return rm is not None and am is not None and rm == am

    def _generate_error_log(self, deps: Dict[str, str]) -> str:
        """Return ``SUCCESS_MSG`` or one human-readable line per conflict.

        Only packages whose exact version string has a registry entry are
        checked; everything else is assumed to have no requirements.
        """
        if not isinstance(deps, dict):
            return "FATAL: Invalid state."
        conflicts = []
        for pkg, ver in sorted(deps.items()):
            entry = self.registry.get(pkg, {}).get(ver)
            if not entry:
                continue
            for req, req_ver in sorted(entry.get("requires", {}).items()):
                cur = deps.get(req)
                if cur is None:
                    conflicts.append(
                        f"Missing peer dependency: {pkg}@{ver} requires {req}@{req_ver}"
                    )
                elif not self.is_compatible(req_ver, cur):
                    conflicts.append(
                        f"Conflict: {pkg}@{ver} requires {req}@{req_ver}, but found {req}@{cur}"
                    )
        return "\n".join(conflicts) if conflicts else self.SUCCESS_MSG

    def _error_count(self, log: str) -> int:
        """Return the number of lines in ``log``, or 0 for ``SUCCESS_MSG``."""
        return 0 if log == self.SUCCESS_MSG else log.count("\n") + 1

    def _shaping_reward(self, old_errors: int, new_errors: int) -> int:
        """Return the progress/regression reward for a change in conflict count.

        Resolved conflicts earn ``PROGRESS_REWARD`` each. Introduced conflicts
        cost ``REGRESSION_PENALTY`` each; that constant is already negative, so
        it is scaled by the positive number of new conflicts.
        """
        delta = old_errors - new_errors
        if delta > 0:
            return delta * self.PROGRESS_REWARD
        if delta < 0:
            return (-delta) * self.REGRESSION_PENALTY
        return 0

    def _validate_action(self, action: "Action") -> Optional[str]:
        """Return an error message for an unusable action, or None if valid.

        Registry packages may only be set to a version the registry lists;
        other packages accept any string matching ``_VERSION_RE``.
        """
        if not isinstance(action, Action):
            return "Error: Action must be an Action instance."
        if not isinstance(action.package_to_update, str) or not isinstance(action.new_version, str):
            return "Error: Action fields must be strings."
        pkg, ver = action.package_to_update.strip(), action.new_version.strip()
        if not pkg or not ver:
            return "Error: Action fields cannot be empty."
        if pkg not in self.current_state:
            return f"Error: Package '{pkg}' not in current state."
        if not self._VERSION_RE.match(ver):
            return f"Error: Invalid version format '{ver}'."
        if ver != "DELETE" and pkg in self.registry and ver not in self.registry[pkg]:
            return f"Error: Unknown version '{ver}' for '{pkg}'."
        return None

    def _transition(
        self, state: Dict[str, str], pkg: str, ver: str
    ) -> Tuple[Dict[str, str], Optional[str]]:
        """Apply ``pkg -> ver`` to a sorted copy of ``state``; never mutates ``state``.

        Returns ``(new_state, None)`` on success or ``(state_copy, message)``
        when the change is fatal.
        """
        s = dict(sorted(state.items()))
        if ver == "DELETE":
            if pkg in self.UNTOUCHABLE:
                return s, "FATAL: Cannot remove a core framework package."
            # pop() rather than del: tolerate a package that is already absent.
            s.pop(pkg, None)
        else:
            s[pkg] = ver
        # invariant: a state always holds at least two packages
        if len(s) < 2:
            return s, "FATAL: State must contain at least 2 packages."
        return dict(sorted(s.items())), None

    def _apply_transition(
        self, state: Dict[str, str], action: "Action"
    ) -> Tuple[Dict[str, str], Optional[str]]:
        """Try to apply ``action`` to a copy of ``state``. Returns (new_state, error)."""
        if not isinstance(state, dict):
            return state, "FATAL: Invalid state."
        return self._transition(
            state, action.package_to_update.strip(), action.new_version.strip()
        )

    def _observe(self, log: str) -> "Observation":
        """Build an observation of the current state carrying ``log``."""
        return Observation(
            json.dumps(self.current_state, indent=2, sort_keys=True),
            log,
            self.step_count,
        )

    def _finish_episode(self, success: bool) -> None:
        """Mark the episode finished and update curriculum progress.

        In curriculum mode a success extends the streak and may advance the
        level; any other ending breaks the streak. Training mode has no
        curriculum state to update.
        """
        self.episode_done = True
        if self.training_mode:
            return
        if not success:
            self.level_success_streak = 0
            return
        self.level_success_streak += 1
        if self.level_success_streak >= self.level_success_threshold:
            order = self._LEVEL_ORDER
            idx = order.index(self.current_level) if self.current_level in order else 0
            if idx < len(order) - 1:
                self.current_level = order[idx + 1]
                self.level_indices[self.current_level] = 0
            self.level_success_streak = 0

    # ------------------------------------------------------------------
    # 4. Gym-like interface (reset, step)
    # ------------------------------------------------------------------
    def reset(self) -> "Observation":
        """Start a fresh episode (random scenario if training_mode, else curriculum)."""
        self.step_count = 0
        self.last_action = None
        self.episode_done = False
        self.episode_count += 1
        if self.training_mode:
            # pick a random scenario from any level
            all_scenarios = [sc for scens in self.scenarios.values() for sc in scens]
            self.current_state = dict(random.choice(all_scenarios))
            self.active_level = "mixed"
        else:
            # cycle through the current level's scenarios in order
            level = self.current_level
            scenarios = self.scenarios[level]
            idx = self.level_indices[level] % len(scenarios)
            self.level_indices[level] = idx + 1
            self.current_state = dict(scenarios[idx])
            self.active_level = level
        errors = self._generate_error_log(self.current_state)
        return Observation(
            current_package_json=json.dumps(self.current_state, indent=2, sort_keys=True),
            npm_error_log=errors,
            step_count=0,
        )

    def step(self, action: "Action") -> Tuple["Observation", int, bool, Dict[str, Any]]:
        """Apply ``action`` and return ``(observation, reward, done, info)``.

        ``info["status"]`` is one of ``in_progress``, ``invalid_action``,
        ``failed``, ``timeout``, ``success`` or ``done`` (episode already over).
        An invalid action is penalised but neither changes the state nor ends
        the episode. The timeout fires on the call after ``max_steps`` steps
        have been taken, and that call's action is not applied.
        """
        if self.episode_done:
            # episode already terminated - return the last observation unchanged
            obs = self._observe(self._generate_error_log(self.current_state))
            return obs, 0, True, {"status": "done", "reason": "episode_already_terminated"}

        self.step_count += 1
        reward = 0
        info: Dict[str, Any] = {"status": "in_progress"}

        # 4a. timeout check
        if self.step_count > self.max_steps:
            state_log = self._generate_error_log(self.current_state)
            log = f"TIMEOUT: Max steps ({self.max_steps}) reached.\n{state_log}"
            reward += self.TIMEOUT_PENALTY
            info["status"] = "timeout"
            self._finish_episode(success=False)
            self.total_reward += reward
            return self._observe(log), reward, True, info

        # 4b. validate action
        error_msg = self._validate_action(action)
        if error_msg:
            reward += self.INVALID_ACTION_PENALTY
            info["status"] = "invalid_action"
            info["error"] = error_msg
            self.total_reward += reward
            return self._observe(error_msg), reward, False, info

        pkg, ver = action.package_to_update.strip(), action.new_version.strip()

        # 4c. penalise repeating the immediately preceding action
        if (pkg, ver) == self.last_action:
            reward += self.NOOP_PENALTY
        self.last_action = (pkg, ver)
        old_errors = self._error_count(self._generate_error_log(self.current_state))

        # 4d. apply transition; fatal errors (e.g. core delete) end the episode
        new_state, err = self._apply_transition(self.current_state, action)
        if err:
            reward += self.FATAL_PENALTY
            info["status"] = "failed"
            info["reason"] = err
            self._finish_episode(success=False)
            self.total_reward += reward
            return self._observe(err), reward, True, info

        # commit state change
        self.current_state = dict(sorted(new_state.items()))

        # 4e. step penalty
        reward += self.STEP_PENALTY
        new_log = self._generate_error_log(self.current_state)
        new_errors = self._error_count(new_log)

        # 4f. progress / regression reward
        reward += self._shaping_reward(old_errors, new_errors)

        # 4g. ultimate success
        done = new_log == self.SUCCESS_MSG
        if done:
            reward += self.SUCCESS_REWARD
            info["status"] = "success"
            self._finish_episode(success=True)

        info["errors_remaining"] = new_errors
        self.total_reward += reward
        return self._observe(new_log), reward, done, info

    # ------------------------------------------------------------------
    # 5. Auxiliary helpers (metrics, serialisation, action mask)
    # ------------------------------------------------------------------
    def get_metrics(self) -> Dict[str, Any]:
        """Return episode count, cumulative reward and the active level name."""
        return {
            "episodes": self.episode_count,
            "total_reward": self.total_reward,
            "active_level": self.active_level,
        }

    def _required_versions(self, pkg: str) -> set:
        """Return every version of ``pkg`` that some registry entry requires."""
        return {
            req_ver
            for versions in self.registry.values()
            for meta in versions.values()
            for req, req_ver in meta.get("requires", {}).items()
            if req == pkg
        }

    def get_action_mask(self) -> Dict[str, List[str]]:
        """Return ``{package: [version, ...]}`` of non-fatal moves for the current state.

        Registry packages offer their registry versions. Other packages offer
        the versions seen in scenarios plus the versions registry entries
        require of them, so the fixing version is always available. The
        current version is left out, and ``"DELETE"`` is offered only for
        non-core packages when more than two packages remain. Packages with no
        remaining option are omitted.
        """
        mask: Dict[str, List[str]] = {}
        base_state = dict(sorted(self.current_state.items()))
        allow_delete = len(base_state) > 2
        for pkg, current_ver in base_state.items():
            if pkg in self.registry:
                versions = list(self.registry[pkg])
            else:
                seen = {
                    s[pkg] for scens in self.scenarios.values() for s in scens if pkg in s
                }
                seen |= self._required_versions(pkg)
                # keep only strings that step() would accept as a version
                versions = sorted(
                    v for v in seen if v != "DELETE" and self._VERSION_RE.match(v)
                )
            candidates = [v for v in versions if v != current_ver]
            if allow_delete and pkg not in self.UNTOUCHABLE:
                candidates.append("DELETE")
            # only keep moves that do not trigger a fatal transition
            valid = [
                v for v in candidates if self._transition(base_state, pkg, v)[1] is None
            ]
            if valid:
                mask[pkg] = valid
        return mask

    def to_dict(self) -> Dict[str, Any]:
        """Lightweight serialisation (enough for checkpointing).

        ``current_state`` is copied so editing the snapshot cannot alter the
        live environment. Scenario indices, episode count and total reward are
        not included.
        """
        return {
            "current_state": dict(self.current_state),
            "step_count": self.step_count,
            "max_steps": self.max_steps,
            "episode_done": self.episode_done,
            "active_level": self.active_level,
            "training_mode": self.training_mode,
            "current_level": self.current_level,
            "level_success_streak": self.level_success_streak,
        }

    def from_dict(self, state: Dict[str, Any]) -> None:
        """Restore the fields written by ``to_dict``; raises KeyError if one is missing."""
        self.current_state = dict(state["current_state"])
        self.step_count = int(state["step_count"])
        self.max_steps = int(state["max_steps"])
        self.episode_done = bool(state["episode_done"])
        self.active_level = str(state["active_level"])
        self.training_mode = bool(state["training_mode"])
        self.current_level = str(state["current_level"])
        self.level_success_streak = int(state["level_success_streak"])
        self.last_action = None  # the previous action is not part of the snapshot
# ---------------------------------------------------------------------------
# 6. Quick local test harness (run this module directly as a script to verify)
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    print("=== Robust NPMResolverEnv self‑test ===")

    # Curriculum mode always serves the first level_1 scenario (react / react-dom)
    # on the first reset, so the assertions below are deterministic.
    env = NPMResolverEnv(training_mode=False, max_steps=10)
    first_obs = env.reset()
    print("Initial state:", first_obs.current_package_json[:80], "...")

    # Test 1: a malformed version string is penalised without ending the episode.
    _, reward, finished, _ = env.step(Action("react-dom", "garbage"))
    assert reward == env.INVALID_ACTION_PENALTY and not finished
    print("[PASS] Invalid version penalised correctly.")

    # Test 2: a package that is not in the current state is rejected.
    _, reward, _, _ = env.step(Action("made-up", "1.0.0"))
    assert reward == env.INVALID_ACTION_PENALTY
    print("[PASS] Unknown package penalised correctly.")

    # Test 3: deleting a core package is fatal and terminates the episode.
    core_env = NPMResolverEnv(training_mode=False)
    core_env.reset()
    _, reward, finished, _ = core_env.step(Action("react", "DELETE"))
    assert reward == env.FATAL_PENALTY and finished
    print("[PASS] Core deletion triggers fatal penalty.")

    # Test 4: a success trajectory, finishing in one or two steps.
    fix_env = NPMResolverEnv(training_mode=False)
    fix_env.reset()
    _, reward, finished, _ = fix_env.step(Action("react", "^18.0.0"))
    solved = fix_env._generate_error_log(fix_env.current_state) == fix_env.SUCCESS_MSG
    if solved:
        assert finished and reward > 0
        print("[PASS] Fixed in 1 step (expected for level_1 simple mismatch).")
    else:
        # a remaining conflict needs a react-dom bump as well
        assert not finished
        _, reward, finished, _ = fix_env.step(Action("react-dom", "^18.0.0"))
        assert finished and reward > 30  # should include success + progress
        print("[PASS] Two‑step fix rewarded correctly.")

    print("\nAll tests passed – environment is bulletproof and trainable.")