| | from mlagents_envs.logging_util import get_logger |
| | from typing import Deque, Dict |
| | from collections import deque |
| | from mlagents.trainers.ghost.trainer import GhostTrainer |
| |
|
# Module-level logger used by GhostController for swap/debug messages.
logger = get_logger(__name__)
| |
|
| |
|
class GhostController:
    """
    Tracks which team is currently learning during ghost (self-play) training.

    GhostTrainers register their team ids through ``subscribe_team_id``. The first
    team to subscribe becomes the learning team; every later team waits in a FIFO
    queue. ``change_training_team`` hands the learning role to the team at the head
    of that queue and re-queues the previous learning team at the tail.

    The controller does not count steps itself: the caller decides when a swap is
    due and calls ``change_training_team``.
    """

    def __init__(self, maxlen: int = 10):
        """
        Create a GhostController.
        :param maxlen: Capacity of the queue of waiting (non-learning) team ids.
            The queue is a bounded deque, so subscribing more teams than fit drops
            a waiting team id from the rotation (a warning is logged when that happens).
        """
        # Team ids waiting for their turn to learn, in rotation order.
        self._queue: Deque[int] = deque(maxlen=maxlen)
        # -1 means "no learning team yet"; any negative value is treated as unset
        # by subscribe_team_id.
        self._learning_team: int = -1
        # Every subscribed trainer keyed by team id (including the learning team).
        self._ghost_trainers: Dict[int, GhostTrainer] = {}
        # Set by change_training_team, consumed (and cleared) by should_reset.
        self._changed_training_team = False

    @property
    def get_learning_team(self) -> int:
        """
        Returns the current learning team. This is a property, so it is read as an
        attribute (``controller.get_learning_team``) rather than called.
        :return: The learning team id, or -1 if no team has subscribed yet
        """
        return self._learning_team

    def should_reset(self) -> bool:
        """
        Whether or not a team change occurred since the last call. Reading the flag
        clears it, so a single swap is reported exactly once.
        :return: True if the learning team changed since the previous call
        """
        changed_team = self._changed_training_team
        self._changed_training_team = False
        return changed_team

    def subscribe_team_id(self, team_id: int, trainer: "GhostTrainer") -> None:
        """
        Register a team id and the trainer that manages it. Re-subscribing an already
        known team id is a no-op. The trainer is kept so that
        ``compute_elo_rating_changes`` can query and update opponent ELO.
        :param team_id: The team_id of an agent managed by this GhostTrainer
        :param trainer: A GhostTrainer that manages this team_id.
        """
        if team_id in self._ghost_trainers:
            return
        self._ghost_trainers[team_id] = trainer

        # The very first subscriber starts out as the learning team.
        if self._learning_team < 0:
            self._learning_team = team_id
            return

        capacity = self._queue.maxlen
        if capacity is not None and len(self._queue) >= capacity:
            # A bounded deque drops an element silently on overflow; make the
            # loss visible instead of letting a team vanish from the rotation.
            logger.warning(
                f"GhostController queue is full (maxlen={capacity}); subscribing team "
                f"{team_id} drops a waiting team from the learning rotation."
            )
        self._queue.append(team_id)

    def change_training_team(self, step: int) -> None:
        """
        Rotate the learning role: the team at the head of the queue becomes the
        learning team and the previous learning team goes to the back of the queue.
        With no waiting team the learning team stays the same, but the change is
        still logged and flagged for ``should_reset``.
        :param step: The step of the trainer, used only in the debug log message
        """
        if self._queue:
            # Pop the successor *before* re-queueing the current team. Appending
            # first would, on a full bounded deque, evict the head - i.e. exactly
            # the team that is supposed to learn next.
            next_team = self._queue.popleft()
            self._queue.append(self._learning_team)
            self._learning_team = next_team
        logger.debug(f"Learning team {self._learning_team} swapped on step {step}")
        self._changed_training_team = True

    def compute_elo_rating_changes(self, rating: float, result: float) -> float:
        """
        Calculates the ELO change for the learning team and forwards it to the
        opposing trainers.

        Every subscribed trainer whose team is not the learning team is asked for its
        opponent ELO; if there are several, the rating of the last one queried is
        used, and if there are none the opponent rating defaults to 0.0. The change
        is ``result - expected``, where ``expected`` is the standard ELO expected
        score of the learning team on a 400-point scale. No K-factor is applied here.
        :param rating: Rating of the learning team.
        :param result: Win, loss, or draw from the perspective of the learning team
            (presumably 1.0 / 0.0 / 0.5 - confirm against the caller).
        :return: The change in ELO.
        """
        opponents = [
            trainer
            for team_id, trainer in self._ghost_trainers.items()
            if team_id != self._learning_team
        ]

        opponent_rating: float = 0.0
        for trainer in opponents:
            opponent_rating = trainer.get_opponent_elo()

        r1 = pow(10, rating / 400)
        r2 = pow(10, opponent_rating / 400)
        expected = r1 / (r1 + r2)

        change = result - expected
        for trainer in opponents:
            trainer.change_opponent_elo(change)

        return change
| |
|