openenv / server /worker.py
MrShadowBlade's picture
Implement Kubernetes action classes and execution logic
57c06cb
from dataclasses import dataclass, field
from typing import List, Callable, Any, Optional, Dict
@dataclass
class StepRecord:
    """Outcome of a single step within an episode.

    Instances are plain value records: they compare field-by-field and
    carry no behavior of their own.
    """

    # 1-based index of the step inside the episode loop.
    step: int
    # Label of the action that was applied on this step.
    action_applied: str
    # Reward credited for this step.
    reward: float
    # Whether the episode was finished after this step.
    done: bool
    # Error message attached to the step; ``None`` when there is none.
    error: Optional[str] = None
@dataclass
class EpisodeResult:
    """Summary of one finished episode together with its per-step log."""

    # Identifier of the task the episode was run for.
    task_id: str
    # Number of steps recorded for the episode.
    steps_taken: int
    # Reward of every recorded step, in step order.
    rewards: List[float]
    # Whether the episode ended successfully.
    success: bool
    # Per-step records; each instance gets its own fresh list by default.
    history: List[StepRecord] = field(default_factory=list)

    @property
    def total_reward(self) -> float:
        """Return the sum of all per-step rewards."""
        accumulated = sum(self.rewards)
        return accumulated
class Worker:
    """Runs episodes against a world and records the outcome of every step."""

    def run_episode(
        self,
        task_id: str,
        world: Any,
        get_action: Callable[[Any], Any],
        max_steps: int,
        grader: Any
    ) -> EpisodeResult:
        """Run a single episode of at most ``max_steps`` steps.

        Each iteration asks ``get_action`` for an action based on the current
        observation, validates it against the world's raw state and, when it
        is valid, executes it and lets ``grader`` score the resulting state.

        An action rejected by the validator still consumes a step: it is
        logged as ``"invalid_action"`` with a reward of ``0.0`` and the
        validator's message as ``error``. The world is not touched and the
        observation is not refreshed, so the next call to ``get_action``
        receives the same observation again.

        Args:
            task_id: Task identifier; forwarded to ``world.reset(task=...)``
                and stored on the result.
            world: Environment object providing ``reset(task=...)`` and
                ``get_raw_state()``; it is also handed to the executor.
            get_action: Callable mapping the current observation to the next
                action.
            max_steps: Upper bound on the number of loop iterations. A value
                of zero or less yields an empty, unsuccessful result.
            grader: Object providing ``grade(state, step, max_steps)`` and
                ``is_done(state)``.

        Returns:
            An ``EpisodeResult`` whose ``steps_taken`` counts every recorded
            step (invalid ones included) and whose ``success`` is the
            ``done`` flag of the last graded step (``False`` if no step was
            ever graded).
        """
        # Kept function-local as in the original (presumably to avoid an
        # import cycle inside the ``server`` package -- TODO confirm), but
        # resolved once per episode instead of once per loop iteration.
        from server.validator import validate
        from server.executor import execute

        obs = world.reset(task=task_id)
        history: List[StepRecord] = []
        rewards: List[float] = []
        done = False

        for step in range(1, max_steps + 1):
            action = get_action(obs)

            validation_error = validate(action, world.get_raw_state())
            if validation_error:
                # Rejected action: log it with zero reward and move on to the
                # next step without executing anything.
                history.append(StepRecord(
                    step=step,
                    action_applied="invalid_action",
                    reward=0.0,
                    done=False,
                    error=validation_error
                ))
                rewards.append(0.0)
                continue

            result = execute(action, world)

            # The state is read again after execution so that grading sees the
            # effect of the action that was just applied.
            reward = grader.grade(world.get_raw_state(), step, max_steps)
            done = grader.is_done(world.get_raw_state())

            history.append(StepRecord(
                step=step,
                action_applied=result.action_applied,
                reward=reward,
                done=done,
                error=None
            ))
            rewards.append(reward)
            obs = result.observation

            if done:
                break

        return EpisodeResult(
            task_id=task_id,
            steps_taken=len(history),
            rewards=rewards,
            success=done,
            history=history
        )