from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional


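@dataclass
class StepRecord:
    """Outcome of a single step within an episode."""

    step: int                    # 1-based step index
    action_applied: str          # executor's description of the action, or "invalid_action"
    reward: float                # grader reward for this step (0.0 for invalid actions)
    done: bool                   # whether the grader marked the episode finished after this step
    error: Optional[str] = None  # validation error message when the action was rejected


@dataclass
class EpisodeResult:
    """Summary of one episode produced by Worker.run_episode."""

    task_id: str
    steps_taken: int
    rewards: List[float]
    success: bool                # True if grader.is_done returned True within max_steps
    history: List[StepRecord] = field(default_factory=list)

    @property
    def total_reward(self) -> float:
        """Sum of per-step rewards, including the 0.0 entries for invalid actions."""
        return sum(self.rewards)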
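

# --- Interface sketch (hypothetical) ---------------------------------------
# Worker.run_episode duck-types its `world` and `grader` arguments and the
# object returned by server.executor.execute. The Protocols below are a
# minimal sketch of the methods and attributes run_episode actually touches,
# inferred only from how it calls them. The names WorldLike, GraderLike and
# ExecutionResultLike are assumptions for illustration, not part of the
# original codebase; they are not used by Worker itself.
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class WorldLike(Protocol):
    def reset(self, task: str) -> Any:
        """Start a new episode for `task` and return the first observation."""
        ...

    def get_raw_state(self) -> Any:
        """Return the state passed to the validator and the grader."""
        ...


@runtime_checkable
class GraderLike(Protocol):
    def grade(self, state: Any, step: int, max_steps: int) -> float:
        """Score `state` after `step` of at most `max_steps` steps."""
        ...

    def is_done(self, state: Any) -> bool:
        """Return True when the episode should terminate."""
        ...


@runtime_checkable
class ExecutionResultLike(Protocol):
    # Attributes run_episode reads from the value returned by execute(action, world).
    action_applied: str
    observation: Any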


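class Worker:
    """Runs a single agent episode against a world and grades each step."""

    def run_episode(
        self,
        task_id: str,
        world: Any,
        get_action: Callable[[Any], Any],
        max_steps: int,
        grader: Any,
    ) -> EpisodeResult:
        """Roll out up to `max_steps` actions for `task_id`.

        Each step asks `get_action` for an action given the latest observation,
        validates it against the world's raw state, executes it if valid, and
        records the grader's reward. Invalid actions are logged with a 0.0
        reward and still consume a step. The episode ends early once
        `grader.is_done` returns True; `success` reports whether that happened.
        """
        # Still imported locally, but hoisted out of the loop so the import
        # statements run once per episode instead of once per step.
        from server.validator import validate
        from server.executor import execute

        obs = world.reset(task=task_id)
        history: List[StepRecord] = []
        rewards: List[float] = []
        done = False

        for step in range(1, max_steps + 1):
            action = get_action(obs)

            # Reject malformed actions before they touch the world; validate
            # returns a truthy error message for an invalid action.
            validation_error = validate(action, world.get_raw_state())
            if validation_error:
                history.append(StepRecord(
                    step=step,
                    action_applied="invalid_action",
                    reward=0.0,
                    done=False,
                    error=validation_error,
                ))
                rewards.append(0.0)
                # obs is not updated, so the agent retries from the same observation.
                continue

            result = execute(action, world)

            # Read the post-action state once and use it for both grading and termination.
            state = world.get_raw_state()
            reward = grader.grade(state, step, max_steps)
            done = grader.is_done(state)

            history.append(StepRecord(
                step=step,
                action_applied=result.action_applied,
                reward=reward,
                done=done,
            ))
            rewards.append(reward)
            obs = result.observation

            if done:
                break

        return EpisodeResult(
            task_id=task_id,
            steps_taken=len(history),
            rewards=rewards,
            success=done,
            history=history,
        )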