jtowarek committed on
Commit
60fe866
·
verified ·
1 Parent(s): 92dbd69

Delete train/trajectory.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train/trajectory.py +0 -206
train/trajectory.py DELETED
@@ -1,206 +0,0 @@
1
- """Trajectory collection for training data generation."""
2
-
3
- from __future__ import annotations
4
-
5
- from dataclasses import dataclass, field
6
- from typing import Any, Callable, Dict, List, Optional
7
-
8
- from env.models import GameAction, GameObservation, RoundResult
9
- from env.environment import KantEnvironment
10
- from constant_definitions.game_constants import EVAL_ZERO_FLOAT
11
-
12
-
13
@dataclass
class StepRecord:
    """One agent decision and its outcome inside an episode trajectory.

    Attributes
    ----------
    prompt : str
        Prompt text captured from the agent for this step.
    completion : str
        Completion text captured from the agent for this step.
    action : str
        Name of the action that was submitted to the environment.
    reward : float
        Step-level reward assigned to this decision.
    player_payoff : float
        Payoff the player received in the resulting round.
    opponent_payoff : float
        Payoff the opponent received in the resulting round.
    round_number : int
        Round counter reported by the environment after the step.
    """

    prompt: str
    completion: str
    action: str
    reward: float
    player_payoff: float
    opponent_payoff: float
    round_number: int
24
-
25
-
26
@dataclass
class EpisodeTrajectory:
    """Every recorded step of one episode plus its end-of-episode summary.

    Attributes
    ----------
    game : str
        Name of the game the episode was played on.
    strategy : str
        Opponent strategy name associated with the episode.
    steps : list of StepRecord
        Per-step records in play order; a fresh empty list by default.
    episode_reward : float
        Episode-level reward; defaults to ``EVAL_ZERO_FLOAT``.
    player_score : float
        Final player score; defaults to ``EVAL_ZERO_FLOAT``.
    opponent_score : float
        Final opponent score; defaults to ``EVAL_ZERO_FLOAT``.
    cooperation_rate : float
        Fraction of cooperative player moves; defaults to ``EVAL_ZERO_FLOAT``.
    rounds_played : int
        Round counter at the end of the episode; defaults to ``int()``.
    metrics : dict of str to float
        Additional named values; a fresh empty dict by default.
    """

    game: str
    strategy: str
    # Mutable defaults go through default_factory so instances never share them.
    steps: List[StepRecord] = field(default_factory=list)
    episode_reward: float = EVAL_ZERO_FLOAT
    player_score: float = EVAL_ZERO_FLOAT
    opponent_score: float = EVAL_ZERO_FLOAT
    cooperation_rate: float = EVAL_ZERO_FLOAT
    rounds_played: int = int()
    metrics: Dict[str, float] = field(default_factory=dict)
39
-
40
-
41
class TrajectoryCollector:
    """Play episodes against the environment and record them as trajectories.

    Parameters
    ----------
    env : KantEnvironment
        Environment that is reset once per episode and stepped once per action.
    agent : LLMAgent
        Callable ``(GameObservation) -> GameAction``.  Its ``last_prompt`` and
        ``last_completion`` attributes are read after every call; a missing
        attribute is recorded as an empty string.
    reward_fn : callable, optional
        Function(player_score, opponent_score, cooperation_rate, rounds) -> float
        producing the episode-level reward.
    step_reward_fn : callable, optional
        Function(player_payoff, opponent_payoff, payoff_min, payoff_max) -> float
        producing the per-step reward.
    """

    def __init__(
        self,
        env: KantEnvironment,
        agent: Any,
        reward_fn: Optional[Callable[..., float]] = None,
        step_reward_fn: Optional[Callable[..., float]] = None,
    ) -> None:
        self._env = env
        self._agent = agent
        self._reward_fn = reward_fn
        self._step_reward_fn = step_reward_fn

    def collect_episode(
        self,
        game: str,
        strategy: str = "tit_for_tat",
        opponent_fn: Optional[Callable] = None,
    ) -> EpisodeTrajectory:
        """Play one episode to completion and return everything recorded.

        Parameters
        ----------
        game : str
            Game name forwarded to ``env.reset``.
        strategy : str
            Opponent strategy forwarded to ``env.reset`` when no
            *opponent_fn* is given.  It is stored on the returned trajectory
            in either case.
        opponent_fn : callable, optional
            When given, it is forwarded to ``env.reset`` instead of
            *strategy*.

        Returns
        -------
        EpisodeTrajectory
            Step records in play order plus final scores, cooperation rate,
            round count and episode reward.
        """
        # An explicit opponent callable takes precedence over the named strategy.
        if opponent_fn is None:
            obs = self._env.reset(game=game, strategy=strategy)
        else:
            obs = self._env.reset(game=game, opponent_fn=opponent_fn)

        records: List[StepRecord] = []
        while not obs.done:
            action = self._agent(obs)

            # Read the agent's text right after the call that produced it.
            prompt = getattr(self._agent, "last_prompt", "")
            completion = getattr(self._agent, "last_completion", "")

            obs = self._env.step(action)
            outcome = obs.last_round

            # Without a round result there are no payoffs to record or score.
            if outcome is None:
                player_payoff = EVAL_ZERO_FLOAT
                opponent_payoff = EVAL_ZERO_FLOAT
            else:
                player_payoff = outcome.player_payoff
                opponent_payoff = outcome.opponent_payoff

            reward = EVAL_ZERO_FLOAT
            if outcome is not None and self._step_reward_fn is not None:
                # NOTE(review): both payoff bounds are passed as
                # EVAL_ZERO_FLOAT; confirm step_reward_fn copes with
                # payoff_min == payoff_max.
                reward = self._step_reward_fn(
                    player_payoff,
                    opponent_payoff,
                    EVAL_ZERO_FLOAT,
                    EVAL_ZERO_FLOAT,
                )

            records.append(StepRecord(
                prompt=prompt,
                completion=completion,
                action=action.action,
                reward=reward,
                player_payoff=player_payoff,
                opponent_payoff=opponent_payoff,
                round_number=obs.current_round,
            ))

        # The final observation supplies the history and the final scores.
        coop_rate = _compute_cooperation_rate(obs)

        episode_reward = EVAL_ZERO_FLOAT
        if self._reward_fn is not None:
            episode_reward = self._reward_fn(
                obs.player_score,
                obs.opponent_score,
                coop_rate,
                obs.current_round,
            )

        return EpisodeTrajectory(
            game=game,
            strategy=strategy,
            steps=records,
            episode_reward=episode_reward,
            player_score=obs.player_score,
            opponent_score=obs.opponent_score,
            cooperation_rate=coop_rate,
            rounds_played=obs.current_round,
        )

    def collect_batch(
        self,
        games: List[str],
        strategies: Optional[List[str]] = None,
        episodes_per_pair: int = int(bool(True)),
        opponent_fn: Optional[Callable] = None,
    ) -> List[EpisodeTrajectory]:
        """Collect trajectories for every (game, strategy) combination.

        Parameters
        ----------
        games : list of str
            Games to play, iterated in the outermost loop.
        strategies : list of str, optional
            Opponent strategies to pair with each game.  ``None`` or an
            empty list falls back to ``["tit_for_tat"]``.
        episodes_per_pair : int
            Number of episodes per combination; defaults to one.
        opponent_fn : callable, optional
            When given, self-play mode is used: only *games* are iterated
            and *strategies* is ignored.

        Returns
        -------
        list of EpisodeTrajectory
            Trajectories ordered by game, then strategy, then repetition.
        """
        if opponent_fn is not None:
            return [
                self.collect_episode(game, opponent_fn=opponent_fn)
                for game in games
                for _ in range(episodes_per_pair)
            ]

        chosen = strategies or ["tit_for_tat"]
        return [
            self.collect_episode(game, strategy)
            for game in games
            for strategy in chosen
            for _ in range(episodes_per_pair)
        ]
172
-
173
-
174
- # ---------------------------------------------------------------------------
175
- # Helpers
176
- # ---------------------------------------------------------------------------
177
-
178
# Player actions counted as cooperative when an episode is not classified
# as an economic game.
_COOPERATIVE_ACTIONS = frozenset(("cooperate", "stag", "dove"))
# Leading tokens (the text before the first "_") that classify an action
# name as belonging to an economic game.
_ECONOMIC_PREFIXES = frozenset(("offer", "invest", "contribute"))

# Small integers spelled without numeric literals.
# NOTE(review): presumably a project-wide convention; confirm before
# replacing these with plain literals.
_ZERO = int()
_ONE = int(True)
_TWO = _ONE + _ONE
184
-
185
-
186
def _compute_cooperation_rate(obs: GameObservation) -> float:
    """Return the share of rounds in which the player moved cooperatively.

    The first recorded player action decides how the whole episode is
    scored.  If the text before its first ``"_"`` is one of
    ``_ECONOMIC_PREFIXES``, a move counts as cooperative when it appears in
    ``obs.available_actions`` at or beyond the midpoint index; moves that
    are not in that list never count.  Otherwise a move counts when it is
    one of ``_COOPERATIVE_ACTIONS``.

    Parameters
    ----------
    obs : GameObservation
        Observation whose ``history`` and ``available_actions`` are read.

    Returns
    -------
    float
        Cooperative moves divided by the number of rounds in the history,
        or ``EVAL_ZERO_FLOAT`` when the history is empty.
    """
    history = obs.history
    if not history:
        return EVAL_ZERO_FLOAT

    opening_prefix = history[_ZERO].player_action.split("_")[_ZERO]
    if opening_prefix in _ECONOMIC_PREFIXES:
        actions = obs.available_actions
        midpoint = len(actions) // _TWO
        # The membership test guards index(), which raises on a missing item.
        cooperative = sum(
            _ONE
            for rnd in history
            if rnd.player_action in actions
            and actions.index(rnd.player_action) >= midpoint
        )
    else:
        cooperative = sum(
            _ONE
            for rnd in history
            if rnd.player_action in _COOPERATIVE_ACTIONS
        )
    return cooperative / len(history)