Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import numpy as np | |
| import numpy as np | |
# Markers used in the textual map.
# NOTE(review): these glyphs look like emoji that were mis-decoded somewhere
# upstream (mojibake) — confirm against the original file's encoding.
agent = 'β½'
opponent = 'π'
goal = 'π₯ '

# Starting layout: 5 rows x 4 columns. The agent starts in the top-left cell,
# the goal sits in the bottom-right cell, and ' ' marks an unoccupied cell.
arena = [
    [agent, ' ', opponent, ' '],
    [' ', ' ', ' ', opponent],
    [' ', opponent, ' ', ' '],
    [' ', ' ', ' ', opponent],
    [' ', opponent, ' ', goal],
]
class Foolsball(object):
    """Grid-world environment in which a single agent moves towards a goal.

    The world is a rectangular grid. Each cell is identified either by its
    ``(row, col)`` indices or by an integer *state* ``row * n_cols + col``.
    Only the agent moves; opponents and the goal are static. Stepping onto
    the goal or onto an opponent ends the episode.
    """

    def __to_state__(self, row, col):
        """Convert from indices (row, col) to integer position."""
        return row * self.n_cols + col

    def __to_indices__(self, state):
        """Convert from integer position to indices (row, col)."""
        return divmod(state, self.n_cols)

    def __deserialize__(self, map: list, agent: str, opponent: str, goal: str):
        """Convert a string representation of a map into a 2D numpy array.

        Param map: list of lists of strings representing the player, opponents and goal.
        Param agent: string representing the agent on the map
        Param opponent: string representing every instance of an opponent player
        Param goal: string representing the location of the goal on the map

        Returns the integer state of the agent's starting cell.
        Raises AssertionError if the map lacks an agent, a goal or any opponent.
        """
        ## Capture dimensions and map.
        self.n_rows = len(map)
        self.n_cols = len(map[0])
        self.n_states = self.n_rows * self.n_cols
        self.map = np.asarray(map)

        ## Store string representations for printing the map, etc.
        self.agent_repr = agent
        self.opponent_repr = opponent
        self.goal_repr = goal

        ## Find initial state, the desired goal state and the state of the opponents.
        self.init_state = None
        self.goal_state = None
        self.opponents_states = []
        for row in range(self.n_rows):
            for col in range(self.n_cols):
                cell = map[row][col]
                if cell == agent:
                    # Store the initial state outside the map.
                    # This helps in quickly resetting the game to the initial state and
                    # also simplifies printing the map independent of the agent's state.
                    self.init_state = self.__to_state__(row, col)
                    self.map[row, col] = ' '
                elif cell == opponent:
                    self.opponents_states.append(self.__to_state__(row, col))
                elif cell == goal:
                    self.goal_state = self.__to_state__(row, col)

        # The message is passed as a string: wrapping it in print() would make
        # the assertion message None and only write the text to stdout.
        assert self.init_state is not None, f"Map {map} does not specify an agent {agent} location"
        assert self.goal_state is not None, f"Map {map} does not specify a goal {goal} location"
        assert self.opponents_states, f"Map {map} does not specify any opponents {opponent} location"
        return self.init_state

    def __init__(self, map, agent, opponent, goal):
        """Spawn the world, create variables to track state and actions."""
        # We just need to track the location of the agent (the ball)
        # Everything else is static and so a potential algorithm doesn't
        # have to look at it. The variable `done` flags terminal states.
        self.state = self.__deserialize__(map, agent, opponent, goal)
        self.done = False
        self.actions = ['n', 'e', 'w', 's']

        # Set up the rewards
        self.default_rewards = {'unmarked': -1, 'opponent': -5, 'outside': -1, 'goal': +5}
        self.set_rewards(self.default_rewards)

    def set_rewards(self, rewards):
        """Install a reward table.

        `rewards` must contain every key of `self.default_rewards`; otherwise
        an AssertionError naming the missing key is raised. A warning is
        printed when called while the agent is away from its initial state.
        """
        if self.state != self.init_state:
            print('Warning: Setting reward while not in initial state! You may want to call reset() first.')
        for key in self.default_rewards:
            assert key in rewards, f'Key {key} missing from reward.'
        self.rewards = rewards

    def reset(self):
        """Reset the environment to its initial state and return that state."""
        # There's really just two things we need to reset: the state, which should
        # be reset to the initial state, and the `done` flag which should be
        # cleared to signal that we are not in a terminal state anymore, even if we
        # were earlier.
        self.state = self.init_state
        self.done = False
        return self.state

    def __get_next_state_on_action__(self, state, action):
        """Return next state based on current state and action.

        `action` is one of 'n', 'e', 'w', 's'. A move that would leave the
        grid is rejected and the current state is returned unchanged.
        """
        row, col = self.__to_indices__(state)
        action_to_index_delta = {'n': [-1, 0], 'e': [0, +1], 'w': [0, -1], 's': [+1, 0]}
        row_delta, col_delta = action_to_index_delta[action]
        new_row, new_col = row + row_delta, col + col_delta

        ## Return current state if next state is invalid
        if not (0 <= new_row < self.n_rows) or not (0 <= new_col < self.n_cols):
            return state

        ## Construct state from new row and col and return it.
        return self.__to_state__(new_row, new_col)

    def __get_reward_for_transition__(self, state, next_state):
        """Return the reward based on the transition from current state to next state."""
        ## Transition rejected due to illegal action (move)
        if next_state == state:
            return self.rewards['outside']
        ## Goal!
        if next_state == self.goal_state:
            return self.rewards['goal']
        ## Ran into opponent.
        if next_state in self.opponents_states:
            return self.rewards['opponent']
        ## Made a safe and valid move.
        return self.rewards['unmarked']

    def __is_terminal_state__(self, state):
        """Return True when `state` is the goal cell or an opponent's cell."""
        return (state == self.goal_state) or (state in self.opponents_states)

    def step(self, action):
        """Simulate state transition based on current state and action received.

        Returns the tuple (next_state, reward, done) and updates `self.state`
        and `self.done`. Raises AssertionError when called after the episode
        has already reached a terminal state.
        """
        assert not self.done, \
            f'You cannot call step() in a terminal state({self.state}). Check the "done" flag before calling step() to avoid this.'
        next_state = self.__get_next_state_on_action__(self.state, action)
        reward = self.__get_reward_for_transition__(self.state, next_state)
        done = self.__is_terminal_state__(next_state)
        self.state, self.done = next_state, done
        return next_state, reward, done

    def render(self, toconsole=True):
        """Pretty-print the environment and agent.

        Returns a copy of the map as a numpy array of dtype '<U3' with the
        agent drawn in. When `toconsole` is true the grid is also printed.
        """
        ## Create a copy of the map and change data type to accommodate
        ## 3-character strings. Longer strings assigned below are truncated
        ## to 3 characters by this dtype.
        _map = np.array(self.map, dtype='<U3')

        ## Mark unoccupied positions with special symbol.
        ## And add extra spacing to align all columns.
        for row in range(_map.shape[0]):
            for col in range(_map.shape[1]):
                if _map[row, col] == ' ':
                    _map[row, col] = ' + '
                elif _map[row, col] == self.opponent_repr:
                    _map[row, col] = self.opponent_repr + ' '
                elif _map[row, col] == self.goal_repr:
                    _map[row, col] = ' ' + self.goal_repr + ' '

        ## If current state overlaps with the goal state or one of the opponents'
        ## states, substitute a distinct marker.
        ## NOTE(review): the two marker glyphs below look mis-decoded — confirm.
        r, c = self.__to_indices__(self.state)
        if self.state == self.goal_state:
            _map[r, c] = ' π '
        elif self.state in self.opponents_states:
            _map[r, c] = ' β '
        else:
            _map[r, c] = ' ' + self.agent_repr

        if toconsole:
            for row in range(_map.shape[0]):
                for col in range(_map.shape[1]):
                    print(f' {_map[row,col]} ', end="")
                print('\n')
            print()
        return _map
# Single shared environment instance used by the UI callback below.
foolsball = Foolsball(map=arena, agent=agent, opponent=opponent, goal=goal)
# Put the agent on its starting cell and clear the `done` flag.
foolsball.reset()
def play(key):
    """Apply one UI command to the shared game and return the board as HTML.

    `key` is one of "Up", "Down", "Left", "Right" (move the agent) or
    "Reset" (restart the game). Any other value yields an HTML page that
    reports the invalid key. Moves are ignored once the game is over; the
    board is still rendered, preceded by a "Game over!!!" line.
    """
    import html  # local import: only needed to escape the echoed key

    key_to_action = {"Up": 'n', "Down": 's', "Left": 'w', "Right": 'e', "Reset": 'r'}
    if key not in key_to_action:
        # The key is echoed back into markup, so escape it first.
        return f"<HTML> <body> Invalid key {html.escape(str(key))} </body> </HTML>"

    act = key_to_action[key]
    if act in foolsball.actions:
        # Ignore moves once a terminal state has been reached.
        if not foolsball.__is_terminal_state__(foolsball.state):
            foolsball.step(act)
    elif act == 'r':
        foolsball.reset()

    grid = foolsball.render(False)

    parts = []
    if foolsball.__is_terminal_state__(foolsball.state):
        parts.append("<p>Game over!!!</p>")
    # One paragraph per board row, each cell padded with a space on both sides.
    for row in range(grid.shape[0]):
        cells = "".join(f' {grid[row,col]} ' for col in range(grid.shape[1]))
        parts.append("<p>" + cells + "</p>")
    parts.append("<p></p>")
    body = "".join(parts)
    return f"<HTML> <body> {body} </body> </HTML>"
# Build and start the web UI. Each radio selection is sent to `play`, whose
# HTML result is shown as the output. "Reset" is offered as a choice because
# `play` handles it; without it a finished game could not be restarted from
# the UI.
gr.Interface(fn=play,
             inputs=gr.Radio(["Up", "Down", "Left", "Right", "Reset"]),
             outputs="html",
             live=True).launch()