File size: 4,334 Bytes
32a197f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""
KernelX Intelligence Layer — RL Environment (OpenEnv structure)

Provides reset/step interface for training the Strategist policy via GRPO.
Replays recorded transitions from the preprocessed JSONL dataset and
computes multi-objective rewards.
"""

import json
import random
from dataclasses import dataclass, field
from typing import List, Tuple

from .rewards import RewardComputer


@dataclass
class KernelState:
    """One observation emitted by the RL environment.

    Attributes:
        features: Active feature vector (10D after preprocessing).
        pid: The ``pid`` value carried by the underlying record.
        cpu: The ``cpu`` value carried by the underlying record.
        timestep: Step counter within the current episode.
        prev_action: Action value applied on the previous step.
    """

    features: List[float]
    pid: int
    cpu: int
    timestep: int
    prev_action: float


@dataclass
class KernelAction:
    """Action emitted by the Strategist policy.

    Attributes:
        value: Scheduling weight in [-1.0, 1.0].
    """

    value: float


class KernelSchedulerEnv:
    """Offline RL environment that replays recorded kernel transitions.

    The dataset is a JSONL file (one JSON object per non-blank line) that is
    loaded fully into memory at construction time.  Each episode starts at a
    random position in the dataset and runs for ``max_steps`` transitions.
    The reward for every step is computed by the multi-objective
    ``RewardComputer``.

    Every record is read as a mapping with the keys ``"state"``, ``"pid"``
    and ``"cpu"``; ``simulate`` additionally reads ``"next_state"``.
    """

    def __init__(
        self,
        data_path: str = "training/data/train.jsonl",
        max_steps: int = 10,
        alpha: float = 1.0,
        beta: float = 2.0,
        gamma: float = 0.5,
    ):
        """Load the recorded transitions and set up the reward computer.

        Args:
            data_path: Path to the JSONL dataset of recorded transitions.
            max_steps: Number of transitions that make up one episode.
            alpha: Forwarded to ``RewardComputer`` as ``alpha``.
            beta: Forwarded to ``RewardComputer`` as ``beta``.
            gamma: Forwarded to ``RewardComputer`` as ``gamma``.

        Raises:
            ValueError: If the dataset holds fewer than ``max_steps + 1``
                records, i.e. not even one full episode fits.
        """
        # Use a context manager so the file handle is closed deterministically
        # instead of lingering until garbage collection; JSON text is UTF-8,
        # so decode it explicitly rather than with the platform default.
        with open(data_path, encoding="utf-8") as fh:
            self.records = [json.loads(line) for line in fh if line.strip()]
        self.max_steps = max_steps
        self.reward_computer = RewardComputer(alpha=alpha, beta=beta, gamma=gamma)

        # Episode state
        self.timestep = 0
        self.current_idx = 0
        self.prev_action = 0.0

        # An episode reads records current_idx .. current_idx + max_steps
        # (the last one is the observation returned by the final step), so
        # max_steps + 1 records are the minimum.
        if len(self.records) < max_steps + 1:
            raise ValueError(
                f"Dataset has {len(self.records)} records but max_steps={max_steps} "
                f"requires at least {max_steps + 1}"
            )

    def reset(self) -> KernelState:
        """Start a fresh episode from a random point in the dataset.

        The start index is drawn so that ``current_idx + max_steps`` is still
        a valid record index.

        Returns:
            The initial observation of the new episode.
        """
        self.timestep = 0
        self.current_idx = random.randint(0, len(self.records) - self.max_steps - 1)
        self.prev_action = 0.0
        return self._get_state()

    def step(self, action: KernelAction) -> Tuple[KernelState, dict, bool]:
        """Apply action, compute reward, advance to next state.

        Args:
            action: The action chosen for the current state.

        Returns:
            next_state: The new KernelState after the transition
            reward_breakdown: Dict with 'total' and per-component rewards,
                as returned by ``RewardComputer.compute_total``
            done: Whether the episode has ended
        """
        current = self.records[self.current_idx + self.timestep]
        next_idx = self.current_idx + self.timestep + 1
        # If the recording ends here, reuse the current record as next state.
        next_rec = self.records[next_idx] if next_idx < len(self.records) else current

        # NOTE(review): ``action`` is forwarded as the KernelAction object
        # while ``prev_action`` is a plain float (``action.value`` of the
        # previous step) -- confirm this matches what compute_total expects.
        reward_breakdown = self.reward_computer.compute_total(
            state=current["state"],
            action=action,
            prev_action=self.prev_action,
            next_state=next_rec["state"],
        )

        self.timestep += 1
        self.prev_action = action.value
        done = self.timestep >= self.max_steps

        return self._get_state(), reward_breakdown, done

    def _get_state(self) -> KernelState:
        """Build the observation for the record at the current episode position."""
        rec = self.records[self.current_idx + self.timestep]
        return KernelState(
            features=rec["state"],
            pid=rec["pid"],
            cpu=rec["cpu"],
            timestep=self.timestep,
            prev_action=self.prev_action,
        )

    def simulate(self, state_features: list, action_value: float) -> list:
        """Lightweight next-state lookup for reward_fn during GRPO.

        Draws a random sample of up to 500 records, finds the one whose
        ``"state"`` is closest (Euclidean distance) to ``state_features`` and
        returns its recorded ``"next_state"``.  This is a simple
        approximation; the World Model provides higher-fidelity simulation.

        Args:
            state_features: Feature vector to look up.
            action_value: Accepted for interface compatibility; the lookup
                currently does not use it.

        Returns:
            The ``"next_state"`` of the nearest sampled record, or
            ``state_features`` itself if no sampled record was closer than
            infinity (e.g. nothing was sampled).
        """
        # Imported lazily so the module does not need numpy unless
        # simulate() is actually called.
        import numpy as np

        state_arr = np.array(state_features)
        best_dist = float("inf")
        best_next = state_features  # fallback

        # Sample a subset to keep this fast
        sample_size = min(500, len(self.records))
        indices = random.sample(range(len(self.records)), sample_size)

        for idx in indices:
            rec = self.records[idx]
            dist = float(np.linalg.norm(state_arr - np.array(rec["state"])))
            # Strict '<' keeps the first record seen among equally near ones.
            if dist < best_dist:
                best_dist = dist
                best_next = rec["next_state"]

        return best_next