QuadOpt-RL / environment /observation_register.py
ropercha
Mesh environment
ba246bb
import ast
import os
from collections import Counter

import pandas as pd
class ObservationRegistry:
    """
    Counts how many times each observation has been seen.
    Observations are turned into hashable IDs (tuples of tuples, see ``encode``)
    and tallied in a ``collections.Counter``. ``df`` exposes the tallies as a
    pandas DataFrame; ``save`` / ``load_counts`` persist them as ``.csv`` or
    ``.parquet`` files.
    """

    def __init__(self, n_darts_selected, deep, lowest_value, highest_value):
        """
        :param n_darts_selected: stored as ``self.n_darts``; not used by this class
        :param deep: stored as ``self.deep``; not used by this class
        :param lowest_value: stored as ``self.low``; not used by this class
        :param highest_value: stored as ``self.high``; not used by this class
        """
        self.n_darts = n_darts_selected
        self.deep = deep
        self.low = lowest_value
        self.high = highest_value
        # A Counter gives O(1) updates/lookups (a DataFrame would be re-concatenated
        # on every new observation) and avoids pandas reading a tuple label passed
        # to ``.loc`` as a (row, column) pair.
        self._counts = Counter()

    @property
    def df(self):
        """
        Returns a snapshot of the counts as a DataFrame.
        :return: DataFrame indexed by observation ID (index name ``"observations"``)
            with a single integer ``"counts"`` column
        """
        return self._build_frame(self._counts.keys())

    @df.setter
    def df(self, frame):
        """
        Replaces all counts with those stored in ``frame``.
        :param frame: DataFrame with a ``"counts"`` column. Observation labels come
            from its index, or from an ``"observations"`` column when present; they
            may be IDs, their string form (as written by ``save``), or nested arrays.
        """
        if "observations" in frame.columns:
            frame = frame.set_index("observations")
        counts = Counter()
        for label, count in frame["counts"].items():
            # += so duplicate labels are merged rather than overwritten.
            counts[self._to_id(label)] += int(count)
        self._counts = counts

    def _build_frame(self, labels):
        """
        Builds a DataFrame pairing ``labels`` with the stored counts.
        :param labels: iterable with one label per stored observation, in Counter order
        :return: DataFrame with index ``"observations"`` and column ``"counts"``
        """
        # tupleize_cols=False keeps tuple IDs as plain labels instead of a MultiIndex.
        index = pd.Index(list(labels), dtype=object, name="observations", tupleize_cols=False)
        return pd.DataFrame({"counts": list(self._counts.values())}, index=index, dtype="int64")

    def encode(self, observation):
        """
        Converts an observation into a hashable tuple of tuples.
        Scalars exposing ``.item()`` (e.g. NumPy scalars) are converted to plain
        Python values: they compare and hash the same, and their ``repr`` is then a
        valid Python literal, which ``save`` / ``load_counts`` rely on.
        :param observation: iterable of rows (e.g. a list of lists or a 2-D array)
        :return: the tuple ID of the observation
        """
        return tuple(
            tuple(value.item() if hasattr(value, "item") else value for value in dart_surrounding)
            for dart_surrounding in observation
        )

    def decode(self, ID):
        """
        Reconstructs an observation from its ID.
        :param ID: a tuple ID
        :return: the observation as a list of lists
        """
        return [list(dart_surrounding) for dart_surrounding in ID]

    def _to_id(self, label):
        """
        Normalizes a stored label back into an observation ID.
        :param label: an ID, its string form as written by ``save``
            (e.g. ``"((1, 2), (3, 4))"``), or a nested sequence/array
        :return: the tuple ID of the observation
        """
        if isinstance(label, str):
            # literal_eval only accepts Python literals, so a tampered file cannot run code.
            label = ast.literal_eval(label)
        return self.encode(label)

    def register_observation(self, observation) -> None:
        """
        Adds an observation to the register and increments its counter.
        :param observation: iterable of rows, as accepted by ``encode``
        """
        self._counts[self.encode(observation)] += 1

    def get_count(self, observation):
        """
        Returns the number of times an observation has been seen by the agent.
        :param observation: iterable of rows, as accepted by ``encode``
        :return: count of the observation (0 if it was never registered)
        """
        # Counter returns 0 for a missing key without inserting it.
        return self._counts[self.encode(observation)]

    def save(self, path):
        """
        Writes the counts to ``path`` as CSV or parquet, chosen by file extension.
        Observation IDs are written in their string form so that both formats can
        be read back by ``load_counts``. Other extensions only print a message.
        :param path: str or os.PathLike ending in ``.csv`` or ``.parquet`` (case-insensitive)
        """
        path = os.fspath(path)
        extension = path.lower()
        if extension.endswith(".csv"):
            self._build_frame(repr(key) for key in self._counts).to_csv(path)
        elif extension.endswith(".parquet"):
            self._build_frame(repr(key) for key in self._counts).to_parquet(path)
        else:
            print("Unsupported file type")

    def load_counts(self, path):
        """
        Replaces the counts with those read from ``path`` (CSV or parquet, by extension).
        Other extensions only print a message and leave the counts unchanged.
        :param path: str or os.PathLike ending in ``.csv`` or ``.parquet`` (case-insensitive)
        """
        path = os.fspath(path)
        extension = path.lower()
        if extension.endswith(".csv"):
            # index_col=0: the first CSV column holds the observation labels.
            self.df = pd.read_csv(path, index_col=0)
        elif extension.endswith(".parquet"):
            self.df = pd.read_parquet(path)
        else:
            print("Unsupported file type")