Spaces:
Build error
Build error
| from __future__ import annotations | |
| import random | |
| from typing import TYPE_CHECKING, Any, List, Union | |
| from . import visibility_registry as VisibilityRegistry | |
| from .base import BaseVisibility | |
| if TYPE_CHECKING: | |
| from agentverse.environments import BaseEnvironment | |
# NOTE(review): `VisibilityRegistry` is imported by this module but is not
# referenced in the visible code -- a registration decorator on this class may
# have been lost; confirm against the upstream file.
class ClassroomVisibility(BaseVisibility):
    """
    Visibility function for classroom, supports group discussion.

    When the previous turn produced exactly one message and that message
    starts with ``"[GroupDiscuss]"``, the agents are split into groups and
    every grouped agent's receiver set is restricted to the names of the
    agents in its own group. After ``num_discussion_turn`` further updates
    the receivers of all agents are reset to ``{"all"}``.

    Args:
        student_per_group:
            The number of students per group.
        num_discussion_turn:
            The number of turns for group discussion.
        grouping:
            The grouping information. If it is a string, then it should be a
            grouping method, options are ["random", "sequential"]. If it is a
            list of list of int, then it should be the grouping information
            (each inner list holds indices into ``environment.agents``).
    """

    grouping: Union[str, List[List[int]]]
    student_per_group: int = 4
    num_discussion_turn: int = 5
    # Number of updates seen since the current group discussion started.
    current_turn: int = 0

    def update_visible_agents(self, environment: BaseEnvironment):
        """Switch grouping mode on or off and update agent receivers.

        Args:
            environment: The environment whose ``last_messages``,
                ``rule_params`` and ``agents`` are read and updated.
        """
        last_messages = environment.last_messages
        launches_discussion = len(last_messages) == 1 and last_messages[
            0
        ].content.startswith("[GroupDiscuss]")

        # We turn on grouping mode when the professor launches a group discussion
        if launches_discussion:
            environment.rule_params["is_grouped"] = True
            # Groups are computed once, at launch, and stored for later turns
            environment.rule_params["groups"] = self.group_students(environment)
            # Update the receiver for each agent
            self.update_receiver(environment)
            return

        # Outside grouping mode there is nothing to track
        if not environment.rule_params.get("is_grouped", False):
            return

        # In grouping mode: count the turn and check if the discussion is over
        self.current_turn += 1
        if self.current_turn < self.num_discussion_turn:
            return

        self.reset()
        environment.rule_params["is_grouped"] = False
        environment.rule_params["is_grouped_ended"] = True
        self.update_receiver(environment, reset=True)

    def group_students(self, environment: BaseEnvironment) -> List[List[int]]:
        """Return the groups of agent indices for a group discussion.

        Args:
            environment: The environment providing the ``agents`` list.

        Returns:
            A list of groups, each a list of indices into
            ``environment.agents``. When ``grouping`` is not a string it is
            returned as-is (the same object, not a copy).

        Raises:
            ValueError: If ``grouping`` is an unsupported method name, or if
                ``student_per_group`` is smaller than 1.
        """
        # If the grouping information is provided, then we use it directly
        if not isinstance(self.grouping, str):
            return self.grouping

        if self.grouping not in ("random", "sequential"):
            raise ValueError(f"Unsupported grouping method {self.grouping}")

        group_size = self.student_per_group
        if group_size < 1:
            # A zero step would make range() fail with an unrelated message and
            # a negative step would silently produce no groups at all.
            raise ValueError(
                f"student_per_group must be a positive integer, got {group_size}"
            )

        # Index 0 is left out of the groups (presumably the professor -- confirm)
        student_index = list(range(1, len(environment.agents)))
        if self.grouping == "random":
            random.shuffle(student_index)

        # Consecutive chunks of `group_size`; the last group may be smaller
        return [
            student_index[start : start + group_size]
            for start in range(0, len(student_index), group_size)
        ]

    def update_receiver(self, environment: BaseEnvironment, reset=False):
        """Set the receiver of each agent according to the current groups.

        Args:
            environment: The environment whose agents are updated. When
                ``reset`` is false, ``environment.rule_params["groups"]`` must
                already be set.
            reset: If true, every agent's receiver is set to ``{"all"}``;
                otherwise each grouped agent's receiver becomes the set of
                names of the agents in its group.
        """
        if reset:
            for agent in environment.agents:
                agent.set_receiver({"all"})
            return

        for group in environment.rule_params["groups"]:
            group_name = {environment.agents[i].name for i in group}
            for agent_id in group:
                environment.agents[agent_id].set_receiver(group_name)

    def reset(self):
        """Reset the discussion turn counter to zero."""
        self.current_turn = 0