Update modeling_jat.py
Browse files- modeling_jat.py +7 -0
modeling_jat.py
CHANGED
|
@@ -711,6 +711,7 @@ class JatModel(GPTNeoPreTrainedModel):
|
|
| 711 |
action_space: Union[spaces.Box, spaces.Discrete] = None,
|
| 712 |
reward: Optional[float] = None,
|
| 713 |
deterministic: bool = False,
|
|
|
|
| 714 |
):
|
| 715 |
# Get the maximum sequence length
|
| 716 |
max_length = self.config.max_position_embeddings // 2
|
|
@@ -804,6 +805,12 @@ class JatModel(GPTNeoPreTrainedModel):
|
|
| 804 |
# We remove the last two values, as the inputs are [s_0, 0], [s_0, a_0, s_1, 0], [s_1, a_1, s_2, 0], ...
|
| 805 |
self._last_key_values = tuple(tuple(pkv[:, :, :-2] for pkv in pkvs) for pkvs in self._last_key_values)
|
| 806 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 807 |
# Return the predicted action
|
| 808 |
if continuous_actions is not None:
|
| 809 |
self.last_continuous_action = outputs.pred_actions[0, -1].cpu().tolist()
|
|
|
|
| 711 |
action_space: Union[spaces.Box, spaces.Discrete] = None,
|
| 712 |
reward: Optional[float] = None,
|
| 713 |
deterministic: bool = False,
|
| 714 |
+
context_window: Optional[int] = None,
|
| 715 |
):
|
| 716 |
# Get the maximum sequence length
|
| 717 |
max_length = self.config.max_position_embeddings // 2
|
|
|
|
| 805 |
# We remove the last two values, as the inputs are [s_0, 0], [s_0, a_0, s_1, 0], [s_1, a_1, s_2, 0], ...
|
| 806 |
self._last_key_values = tuple(tuple(pkv[:, :, :-2] for pkv in pkvs) for pkvs in self._last_key_values)
|
| 807 |
|
| 808 |
+
# Context window
|
| 809 |
+
if context_window is not None:
|
| 810 |
+
self._last_key_values = tuple(
|
| 811 |
+
tuple(pkv[:, :, -context_window:] for pkv in pkvs) for pkvs in self._last_key_values
|
| 812 |
+
)
|
| 813 |
+
|
| 814 |
# Return the predicted action
|
| 815 |
if continuous_actions is not None:
|
| 816 |
self.last_continuous_action = outputs.pred_actions[0, -1].cpu().tolist()
|