Spaces:
Runtime error
Runtime error
Commit
Β·
7db5284
1
Parent(s):
70fb224
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
arena = [['β½', ' ' , 'π', ' ' ],
|
| 6 |
+
[' ' , ' ' , ' ' , 'π'],
|
| 7 |
+
[' ' , 'π', ' ' , ' ' ],
|
| 8 |
+
[' ' , ' ' , ' ' , 'π'],
|
| 9 |
+
[' ' , 'π', ' ' , 'π₯
']]
|
| 10 |
+
|
| 11 |
+
agent = 'β½'
|
| 12 |
+
opponent = 'π'
|
| 13 |
+
goal = 'π₯
'
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Foolsball(object):
|
| 17 |
+
|
| 18 |
+
def __to_state__(self,row,col):
|
| 19 |
+
return row*self.n_cols + col
|
| 20 |
+
|
| 21 |
+
def __to_indices__(self, state):
|
| 22 |
+
row = state / self.n_rows
|
| 23 |
+
col = state % self.n_rows
|
| 24 |
+
return row,col
|
| 25 |
+
|
| 26 |
+
def __deserialize__(self,map:list,agent:str,opponent:str, goal:str):
|
| 27 |
+
self.n_rows = len(map)
|
| 28 |
+
self.n_cols = len(map[0])
|
| 29 |
+
self.n_states = self.n_rows * self.n_cols
|
| 30 |
+
self.map = np.asarray(map)
|
| 31 |
+
self.agent_repr = agent
|
| 32 |
+
self.opponent_repr = opponent
|
| 33 |
+
self.goal_repr = goal
|
| 34 |
+
|
| 35 |
+
self.init_state = None
|
| 36 |
+
self.goal = None
|
| 37 |
+
self.opponents = []
|
| 38 |
+
|
| 39 |
+
for row in range(self.n_rows):
|
| 40 |
+
for col in range(self.n_cols):
|
| 41 |
+
if map[row][col] == agent:
|
| 42 |
+
self.init_state = self.__to_state__(row,col)
|
| 43 |
+
self.map[row,col] = ' ' #Store state outside map
|
| 44 |
+
|
| 45 |
+
elif map[row][col] == agent:
|
| 46 |
+
self.opponents.append(self.__to_state__(row,col))
|
| 47 |
+
|
| 48 |
+
elif map[row][col] == goal:
|
| 49 |
+
self.goal = self.__to_state__(row,col)
|
| 50 |
+
|
| 51 |
+
assert self.init_state is not None, print(f"Map {map} does not specify an agent {agent} location")
|
| 52 |
+
assert self.goal is not None, print(f"Map {map} does not specify a goal {goal} location")
|
| 53 |
+
assert not self.opponents, print(f"Map {map} does not specify any opponents {opponent} location")
|
| 54 |
+
|
| 55 |
+
return self.init_state
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def __get_next_state_on_action__(self,state,action):
|
| 59 |
+
row, col = self.__to_indices__(state)
|
| 60 |
+
action_to_index_delta = {'n':[-1,0], 'e':[0,+1], 'w':[0,-1], 's':[+1,0]}
|
| 61 |
+
|
| 62 |
+
row_delta, col_delta = action_to_index_delta[action]
|
| 63 |
+
new_row , new_col = row+row_delta, col+col_delta
|
| 64 |
+
|
| 65 |
+
## Return current state if next state is invalid
|
| 66 |
+
if not(0<=new_row<self.n_rows) or not(0<=new_col<self.n_cols):
|
| 67 |
+
return state
|
| 68 |
+
|
| 69 |
+
return self.__to_state__(new_row, new_col)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def __init__(self,map,agent,opponent,goal):
|
| 73 |
+
self.state = self.__deserialize__(map,agent,opponent,goal)
|
| 74 |
+
self.done = False
|
| 75 |
+
self.actions = ['n','e','w','s']
|
| 76 |
+
self.transitions = self.__install_transition_table__()
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def __install_transition_table__(self):
|
| 80 |
+
transitions = {}
|
| 81 |
+
for s in range(self.n_states):
|
| 82 |
+
for a in self.actions:
|
| 83 |
+
transitions[s] = self.__get_next_state_on_action__(s,a)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def reset(self):
|
| 87 |
+
self.state = self.init_state
|
| 88 |
+
self.done = False
|
| 89 |
+
return self.state
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def step(self,action):
|
| 93 |
+
assert not self.done, \
|
| 94 |
+
print(f'No actions supported in a terminal state {self.state}.'+
|
| 95 |
+
' Check the "done" flag before calling step()')
|
| 96 |
+
next_state = self.__get_next_state_on_action__(self.state, action)
|
| 97 |
+
|
| 98 |
+
## Transition rejected due to illegal action (move)
|
| 99 |
+
if next_state == self.state:
|
| 100 |
+
reward = -1
|
| 101 |
+
done = False
|
| 102 |
+
|
| 103 |
+
## Goal!
|
| 104 |
+
elif next_state == self.goal:
|
| 105 |
+
reward = +5
|
| 106 |
+
done = True
|
| 107 |
+
|
| 108 |
+
## Ran into opponent. Heavy penalty.
|
| 109 |
+
elif self.__to_indices__(next_state) in self.opponents:
|
| 110 |
+
reward = -5
|
| 111 |
+
done = True
|
| 112 |
+
|
| 113 |
+
## Made a safe and valid move. Penalize to take the shortest route.
|
| 114 |
+
else:
|
| 115 |
+
reward = -1
|
| 116 |
+
done = False
|
| 117 |
+
|
| 118 |
+
self.state, self.done = next_state, done
|
| 119 |
+
return next_state, reward, done
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def render(self):
|
| 123 |
+
canvas = ""
|
| 124 |
+
def print(str):
|
| 125 |
+
canvas.append(str)
|
| 126 |
+
canvas.append('\n')
|
| 127 |
+
|
| 128 |
+
for row in range(self.map.shape[0]):
|
| 129 |
+
for col in range(self.map.shape[1]):
|
| 130 |
+
|
| 131 |
+
if (row,col) == self.__to_indices__(self.goal):
|
| 132 |
+
if self.state == self.goal:
|
| 133 |
+
print(f'{self.agent_repr}{self.goal_repr} ')
|
| 134 |
+
else:
|
| 135 |
+
print(f' {self.goal_repr} ')
|
| 136 |
+
|
| 137 |
+
elif (row,col) in self.opponents:
|
| 138 |
+
print(f' {self.opponent_repr} ')
|
| 139 |
+
|
| 140 |
+
elif (row,col) == self.__to_indices__(self.state):
|
| 141 |
+
if self.state == self.goal:
|
| 142 |
+
print(f'{self.agent_repr}{self.goal_repr} ')
|
| 143 |
+
else:
|
| 144 |
+
print(f' {self.agent_repr} ')
|
| 145 |
+
|
| 146 |
+
else:
|
| 147 |
+
print(' ')
|
| 148 |
+
|
| 149 |
+
print('\n')
|
| 150 |
+
return f"<HTML> <body> {canvas} </body> </HTML>"
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
foolsball = Foolsball(arena, agent, opponent, goal)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def play(key):
|
| 157 |
+
key_to_action = {"Up":'n', "Down":'s', "Left":'w', "Right":'e'}
|
| 158 |
+
if key not in key_to_action:
|
| 159 |
+
return f"<HTML> <body> Invalid key {canvas} </body> </HTML>"
|
| 160 |
+
return foolsball.render()
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
gr.Interface(fn=play,
|
| 166 |
+
inputs=gr.radio(["Up","Down","Left","Right"]),
|
| 167 |
+
outputs="html",
|
| 168 |
+
live=True).launch()
|
| 169 |
+
|