"""
KernelX Intelligence Layer — RL Environment (OpenEnv structure)
Provides reset/step interface for training the Strategist policy via GRPO.
Replays recorded transitions from the preprocessed JSONL dataset and
computes multi-objective rewards.
"""
import json
import random
from dataclasses import dataclass, field
from typing import List, Tuple
from .rewards import RewardComputer
@dataclass
class KernelState:
    """Single observation handed out by the RL environment.

    Attributes:
        features: Active feature vector (10-D after preprocessing).
        pid: ``pid`` value taken from the replayed dataset record.
        cpu: ``cpu`` value taken from the replayed dataset record.
        timestep: Step counter within the current episode.
        prev_action: Action value applied on the previous step.
    """

    features: List[float]
    pid: int
    cpu: int
    timestep: int
    prev_action: float
@dataclass
class KernelAction:
    """Action emitted by the Strategist policy."""

    # Scheduling weight; the intended range is [-1.0, 1.0], but nothing in
    # this class enforces it.
    value: float
class KernelSchedulerEnv:
    """Offline RL environment that replays recorded kernel transitions.

    Each episode starts at a random position in the dataset and runs for
    ``max_steps`` transitions. The reward is computed by the
    multi-objective ``RewardComputer``.
    """

    def __init__(
        self,
        data_path: str = "training/data/train.jsonl",
        max_steps: int = 10,
        alpha: float = 1.0,
        beta: float = 2.0,
        gamma: float = 0.5,
    ):
        """Load the recorded transitions and initialise episode bookkeeping.

        Args:
            data_path: Path to a JSONL file with one JSON record per line.
                Blank lines are skipped.
            max_steps: Number of transitions in one episode.
            alpha: Forwarded to ``RewardComputer`` as ``alpha``.
            beta: Forwarded to ``RewardComputer`` as ``beta``.
            gamma: Forwarded to ``RewardComputer`` as ``gamma``.

        Raises:
            ValueError: If the dataset holds fewer than ``max_steps + 1``
                records, i.e. not even one full episode fits.
        """
        # Use a context manager so the file handle is closed
        # deterministically; JSON text is read as UTF-8.
        with open(data_path, encoding="utf-8") as handle:
            self.records = [json.loads(line) for line in handle if line.strip()]

        # Validate right after loading so an unusable dataset fails fast,
        # before anything else is constructed.
        if len(self.records) < max_steps + 1:
            raise ValueError(
                f"Dataset has {len(self.records)} records but max_steps={max_steps} "
                f"requires at least {max_steps + 1}"
            )

        self.max_steps = max_steps
        self.reward_computer = RewardComputer(alpha=alpha, beta=beta, gamma=gamma)

        # Episode state
        self.timestep = 0
        self.current_idx = 0
        self.prev_action = 0.0

    def reset(self) -> "KernelState":
        """Start a fresh episode from a random point in the dataset.

        Returns:
            The observation at the start of the new episode.
        """
        self.timestep = 0
        # The upper bound leaves room for max_steps transitions plus the
        # final observation without running off the end of the dataset.
        self.current_idx = random.randint(0, len(self.records) - self.max_steps - 1)
        self.prev_action = 0.0
        return self._get_state()

    def step(self, action: "KernelAction") -> "Tuple[KernelState, dict, bool]":
        """Apply action, compute reward, advance to next state.

        Args:
            action: The action chosen for the current state.

        Returns:
            next_state: The new KernelState after the transition
            reward_breakdown: Dict with 'total' and per-component rewards
            done: Whether the episode has ended
        """
        cursor = self.current_idx + self.timestep
        current = self.records[cursor]
        next_idx = cursor + 1
        # At the very end of the dataset there is no following record, so
        # the current record doubles as its own successor.
        next_rec = self.records[next_idx] if next_idx < len(self.records) else current

        # NOTE(review): ``action`` is passed as the KernelAction object while
        # ``prev_action`` is a plain float -- confirm this matches what
        # RewardComputer.compute_total expects.
        reward_breakdown = self.reward_computer.compute_total(
            state=current["state"],
            action=action,
            prev_action=self.prev_action,
            next_state=next_rec["state"],
        )

        self.timestep += 1
        self.prev_action = action.value
        done = self.timestep >= self.max_steps
        return self._get_state(), reward_breakdown, done

    def _get_state(self) -> "KernelState":
        """Build the observation for the record at the current position."""
        rec = self.records[self.current_idx + self.timestep]
        return KernelState(
            features=rec["state"],
            pid=rec["pid"],
            cpu=rec["cpu"],
            timestep=self.timestep,
            prev_action=self.prev_action,
        )

    def simulate(self, state_features: list, action_value: float) -> list:
        """Lightweight next-state lookup for reward_fn during GRPO.

        Finds the nearest recorded state (Euclidean distance) among a random
        sample of at most 500 records and returns that record's
        ``next_state``. This is a simple approximation; the World Model
        provides higher-fidelity simulation.

        Args:
            state_features: Feature vector to look up.
            action_value: Currently unused; the lookup depends only on
                ``state_features``.

        Returns:
            The ``next_state`` of the closest sampled record, or
            ``state_features`` itself when there are no records to sample.
        """
        import numpy as np

        state_arr = np.array(state_features)
        best_dist = float("inf")
        best_next = state_features  # fallback

        # Sample a subset to keep this fast
        sample_size = min(500, len(self.records))
        indices = random.sample(range(len(self.records)), sample_size)

        for idx in indices:
            rec = self.records[idx]
            dist = float(np.linalg.norm(state_arr - np.array(rec["state"])))
            if dist < best_dist:
                best_dist = dist
                best_next = rec["next_state"]
        return best_next
|