| | |
| | |
| | |
| | |
| | @@ -2,10 +2,10 @@ import glob |
| | import json |
| | import subprocess |
| | |
| | -import wandb |
| | from accelerate import Accelerator |
| | from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments |
| | |
| | +import wandb |
| | from gia.config import Arguments |
| | from gia.eval.utils import is_slurm_available |
| | |
| | |
| | |
| | |
| | |
| | @@ -1,3 +1,5 @@ |
| | +from typing import Optional |
| | + |
| | import torch |
| | |
| | from gia.config.arguments import Arguments |
| | @@ -5,11 +7,12 @@ from gia.model import GiaModel |
| | |
| | |
| | class Evaluator: |
| | - def __init__(self, args: Arguments, task: str) -> None: |
| | + def __init__(self, args: Arguments, task: str, mean_random: Optional[float] = None) -> None: |
| | self.args = args |
| | self.task = task |
| | + self.mean_random = mean_random |
| | |
| | - @torch.no_grad() |
| | + @torch.inference_mode() |
| | def evaluate(self, model: GiaModel) -> float: |
| | return self._evaluate(model) |
| | |
| | |
| | |
| | |
| | |
| | @@ -177,7 +177,6 @@ def make(task_name: str, num_envs: int = 1): |
| | |
| | elif task_name.startswith("metaworld"): |
| | import gymnasium as gym |
| | - import metaworld |
| | |
| | env_id = TASK_TO_ENV_MAPPING[task_name] |
| | env = gym.vector.SyncVectorEnv([lambda: gym.make(env_id)] * num_envs) |
| | |
| | |
| | |
| | |
| | @@ -54,7 +54,7 @@ class GiaAgent: |
| | self.action_space = action_space |
| | self.deterministic = deterministic |
| | self.device = next(model.parameters()).device |
| | - self._max_length = self.model.config.max_position_embeddings - 10 |
| | + self._max_length = self.model.config.max_position_embeddings - 100 # TODO: fix this |
| | |
| | if isinstance(observation_space, spaces.Box): |
| | self._observation_key = "continuous_observations" |
| | @@ -75,6 +75,11 @@ class GiaAgent: |
| | ) -> Tuple[Tuple[Tensor, Tensor], ...]: |
| | return tuple((k[:, :, -self._max_length :], v[:, :, -self._max_length :]) for (k, v) in past_key_values) |
| | |
| | + def set_model(self, model: GiaModel) -> None: |
| | + self.model = model |
| | + self.device = next(model.parameters()).device |
| | + self._max_length = self.model.config.max_position_embeddings |
| | + |
| | def reset(self, num_envs: int = 1) -> None: |
| | if self.prompter is not None: |
| | prompts = self.prompter.generate_prompts(num_envs) |
| | |
| | |
| | |
| | |
| | @@ -1,7 +1,7 @@ |
| | import gym |
| | from gym.vector.vector_env import VectorEnv |
| | |
| | -from gia.eval.mappings import TASK_TO_ENV_MAPPING |
| | +# from gia.eval.rl.envs.mappings import TASK_TO_ENV_MAPPING |
| | from gia.eval.rl.rl_evaluator import RLEvaluator |
| | |
| | |
| | |
| | |
| | |
| | |
| | @@ -8,6 +8,10 @@ from gia.eval.rl.gia_agent import GiaAgent |
| | |
| | |
| | class RLEvaluator(Evaluator): |
| | + def __init__(self, args, task): |
| | + super().__init__(args, task) |
| | + self.agent = GiaAgent() |
| | + |
| | def _build_env(self) -> VectorEnv: # TODO: maybe just a gym.Env ? |
| | raise NotImplementedError |
| | |
| | |
| | |
| | |
| | |
| | @@ -929,8 +929,8 @@ |
| | }, |
| | "metaworld-assembly": { |
| | "expert": { |
| | - "mean": 311.29314618777823, |
| | - "std": 75.04282151450695 |
| | + "mean": 3523.81468486244, |
| | + "std": 63.22745220327798 |
| | }, |
| | "random": { |
| | "mean": 220.65601680730813, |
| |
|