Spaces:
Running
on
Zero
Running
on
Zero
| ####################################################################### | |
| # Name: test_info_surfing.py | |
| # | |
| # - Runs robot in environment using Info Surfing Planner | |
| ####################################################################### | |
| import sys | |
| sys.modules['TRAINING'] = False # False = Inference Testing | |
| import copy | |
| import os | |
| import imageio | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from pathlib import Path | |
| from time import time | |
| from types import SimpleNamespace | |
| from skimage.transform import resize | |
| from taxabind_avs.satbind.kmeans_clustering import CombinedSilhouetteInertiaClusterer | |
| from .env import Env | |
| from .test_parameter import * | |
# Move id -> id of the move that undoes it.
# Ids follow the action docstrings in ISEnv: 1=N, 2=E, 3=S, 4=W, 5-8 = diagonals.
OPPOSITE_ACTIONS = {
    1: 3,
    2: 4,
    3: 1,
    4: 2,
    5: 7,
    6: 8,
    7: 5,
    8: 6,
}

# Colour triples with components in [0, 1] (presumably RGB for plotting -- confirm
# against the code that consumes them; none of it is visible here).
agentColor = (1, 0.2, 0.6)
agentCommColor = (1, 0.6, 0.2)
obstacleColor = (0.0, 0.0, 0.0)
targetNotFound = (0.0, 1.0, 0.0)
targetFound = (0.545, 0.27, 0.075)
highestProbColor = (1.0, 0.0, 0.0)
highestUncertaintyColor = (0.0, 0.0, 1.0)
lowestProbColor = (1.0, 1.0, 1.0)
| class ISEnv: | |
| """Custom Environment that follows gym interface""" | |
| metadata = {'render.modes': ['human']} | |
def __init__(
    self,
    global_step=0,
    state=None,
    shape=(24, 24),
    numAgents=8,
    observationSize=11,
    sensorSize=1,
    diag=False,
    save_image=False,
    clip_seg_tta=None,
):
    """Create the wrapped Env and seed the planner's world from it.

    Args:
        global_step: stored on the instance; also passed to Env as map_index
            and to clip_seg_tta.reset() as sample_idx.
        state: not used -- the world state is always rebuilt from the Env
            at the end of this constructor.
        shape: grid shape used for bounds checks and for reshaping the info map.
        numAgents: passed to Env as n_agent (setWorld later resets
            self.numAgents from the Env start positions).
        observationSize: side length of the window returned by observe().
        sensorSize: side length of the window decayed around an agent.
        diag: when true, the four diagonal moves are enabled as well.
        save_image: stored on the instance; forwarded as plot= to Env and
            to the k-means clusterer.
        clip_seg_tta: optional project object; when given it is reset and
            its target positions replace the Env's.

    NOTE(review): block nesting below was reconstructed from a listing whose
    indentation was lost -- confirm against the original file.
    """
    # --- configuration ---
    self.global_step = global_step
    self.shape = shape
    self.numAgents = numAgents
    self.observationSize = observationSize
    self.sensorSize = sensorSize
    self.diag = diag
    self.save_image = save_image
    self.clip_seg_tta = clip_seg_tta
    self.communicateCircle = 11

    # --- world state (filled by setWorld) ---
    self.infoMap = None
    self.targetMap = None
    self.agents = []
    self.targets = []
    self.found_target = []
    self.distribs = []
    self.mask = None
    self.finished = False

    # Unit-ish direction vectors in (row, col) order; index i is move id i + 1.
    cardinal_vects = [[-1., 0.], [0., 1.], [1., 0], [0., -1.]]
    diagonal_vects = [[-0.707, -0.707], [-0.707, 0.707], [0.707, 0.707], [0.707, -0.707]]
    self.action_vects = cardinal_vects + diagonal_vects if diag else cardinal_vects

    # --- per-episode bookkeeping ---
    self.actionlist = []
    self.IS_step = 0
    self.perf_metrics = dict()
    self.steps_to_first_tgt = None
    self.steps_to_mid_tgt = None
    self.steps_to_last_tgt = None
    self.targets_found_on_path = []
    self.step_since_tta = 0
    self.IS_frame_files = []
    self.bad_mask_init = False

    # define env
    self.env = Env(map_index=self.global_step, n_agent=numAgents, k_size=K_SIZE, plot=save_image, test=True)

    # Overwrite state
    if self.clip_seg_tta is not None:
        self.clip_seg_tta.reset(sample_idx=self.global_step)

        # Override target positions in env (each pose has its two components swapped)
        self.env.target_positions = [(pose[1], pose[0]) for pose in self.clip_seg_tta.target_positions]

        # Override segmentation mask
        if not USE_CLIP_PREDS and OVERRIDE_MASK_DIR != "":
            score_mask_path = os.path.join(OVERRIDE_MASK_DIR, self.clip_seg_tta.gt_mask_name)
            print("score_mask_path: ", score_mask_path)
            if not os.path.exists(score_mask_path):
                print(f"\n\n{RED}ERROR: Trying to override, but score mask not found at path:{NC} ", score_mask_path)
                # is_scenario() checks this flag and reports all metrics as None
                self.bad_mask_init = True
            else:
                self.env.segmentation_mask = self.env.import_segmentation_mask(score_mask_path)
                self.env.begin(self.env.map_start_position)

        # Save clustered embeds from sat encoder
        if USE_CLIP_PREDS:
            self.kmeans_clusterer = CombinedSilhouetteInertiaClusterer(
                k_min=1,
                k_max=8,
                k_avg_max=4,
                silhouette_threshold=0.15,
                relative_threshold=0.15,
                random_state=0,
                min_patch_size=5,
                n_smooth_iter=2,
                ignore_label=-1,
                plot=self.save_image,
                gifs_dir=GIFS_PATH,
            )
            # Generate kmeans clusters
            self.kmeans_sat_embeds_clusters = self.kmeans_clusterer.fit_predict(
                patch_embeds=self.clip_seg_tta.patch_embeds,
                map_shape=(CLIP_GRIDS_DIMS[0], CLIP_GRIDS_DIMS[1]),
            )
            if EXECUTE_TTA:
                print("Will execute TTA...")

    # Deep copies: setWorld consumes the lists it is given.
    world_state = [
        copy.deepcopy(self.env.segmentation_info_mask),
        copy.deepcopy(self.env.start_positions),
        copy.deepcopy(self.env.target_positions),
    ]
    self.setWorld(world_state)
def init_render(self):
    """Reset the per-agent path buffers and open an interactive figure.

    Intended to be called once before the scenario loop. It creates
    self.trajectories and self.trajectories_upscaled (one empty list per
    agent), which record_positions() appends to.
    """
    agent_count = self.numAgents
    # One independent list per agent for each buffer.
    self.trajectories = [list() for _ in range(agent_count)]
    self.trajectories_upscaled = [list() for _ in range(agent_count)]

    # Interactive mode lets the same figure be refreshed repeatedly.
    plt.ion()
    plt.figure(figsize=(6, 6))
    plt.title("Information Map with Agents, Targets, and Sensor Ranges")
def record_positions(self):
    """Append every agent's current cell to its recorded paths.

    For agent i, (row, col) goes to self.trajectories[i] and the matching
    entry of env.graph_generator.grid_coords goes to
    self.trajectories_upscaled[i]. Call after the agents have moved.
    """
    for slot, agent in enumerate(self.agents):
        cell = (agent.row, agent.col)
        self.trajectories[slot].append(cell)
        upscaled = self.env.graph_generator.grid_coords[agent.row, agent.col]
        self.trajectories_upscaled[slot].append(upscaled)
def render(self, episode_num, step_num):
    """Draw the current info map, agent paths and targets, and save a PNG frame.

    The frame is written to '{GIFS_PATH}/IS_{episode_num}_{step_num}.png' and
    its path is appended to self.IS_frame_files.
    """
    # Start from a clean slate so nothing from a previous frame leaks in.
    plt.close('all')
    fig = plt.figure(figsize=(6.4, 4.8), dpi=100)
    ax = fig.add_subplot(111)

    # Background: the information map.
    ax.imshow(self.infoMap, origin='lower', cmap='gray')

    # Agents: path so far (only once it has at least two points) plus a marker.
    for slot, agent in enumerate(self.agents):
        path = self.trajectories[slot]
        if len(path) > 1:
            path_rows = [cell[0] for cell in path]
            path_cols = [cell[1] for cell in path]
            ax.plot(path_cols, path_rows, linewidth=1)
        ax.scatter(agent.col, agent.row, marker='o', s=50)

    # Targets: green while time_found is still NaN, red afterwards.
    for target in self.targets:
        still_hidden = np.isnan(target.time_found)
        marker_color = 'green' if still_hidden else 'red'
        ax.scatter(target.col, target.row, marker='x', s=100, color=marker_color)

    ax.set_title(f"Step: {self.IS_step}")
    ax.invert_yaxis()

    # Make sure the output folder exists.
    if not os.path.exists(GIFS_PATH):
        os.makedirs(GIFS_PATH)

    frame_path = f'{GIFS_PATH}/IS_{episode_num}_{step_num}.png'
    plt.savefig(frame_path, bbox_inches='tight', pad_inches=0.1)
    self.IS_frame_files.append(frame_path)

    # Cleanup
    plt.close(fig)
def setWorld(self, state=None):
    """Rebuild the info map, agents and targets from ``state``.

    Args:
        state: None (nothing happens) or an indexable
            [info_map, agent_positions, target_positions]. info_map must
            support reshape(self.shape); each position is converted with
            env.graph_generator.find_closest_index_from_grid_coords_2d.

    Side effects: replaces self.infoMap (transposed copy of info_map),
    self.numAgents, self.agents (IDs start at 1) and self.targets (IDs start
    at 1, time_found = NaN).

    The position lists in ``state`` are only read. Earlier versions drained
    them with pop(0), which silently emptied the caller's lists.
    """
    if state is None:
        return

    info_map, agent_positions, target_positions = state[0], state[1], state[2]

    self.infoMap = copy.deepcopy(info_map.reshape(self.shape).T)

    self.numAgents = len(agent_positions)
    agents = []
    for agent_id, abs_pos in enumerate(agent_positions, start=1):
        row, col = self.env.graph_generator.find_closest_index_from_grid_coords_2d(np.array(abs_pos))
        # Each agent gets its own copies of the map.
        agents.append(Agent(ID=agent_id, row=row, col=col, sensorSize=self.sensorSize,
                            infoMap=np.copy(self.infoMap), uncertaintyMap=np.copy(self.infoMap),
                            shape=self.shape, numAgents=self.numAgents))
    self.agents = agents

    targets = []
    for target_id, abs_pos in enumerate(target_positions, start=1):
        row, col = self.env.graph_generator.find_closest_index_from_grid_coords_2d(np.array(abs_pos))
        targets.append(Target(ID=target_id, row=row, col=col, time_found=np.nan))
    self.targets = targets
def extractObservation(self, agent):
    """Return the info-map window centred on ``agent``, zero-padded at the edges.

    Args:
        agent: object with integer ``row`` and ``col`` attributes.

    Returns:
        float array of shape (1, observationSize, observationSize). The cell
        at [0, observationSize // 2, observationSize // 2] corresponds to the
        agent's own cell; cells that fall outside the map stay 0.
    """
    size = self.observationSize
    half = size // 2

    # Part of the map that is actually covered by the window.
    row_lo = max(agent.row - half, 0)
    row_hi = min(agent.row + half + 1, self.shape[0])
    col_lo = max(agent.col - half, 0)
    col_hi = min(agent.col + half + 1, self.shape[1])

    # Offset that maps map coordinates to window coordinates.
    row_shift = half - agent.row
    col_shift = half - agent.col

    # Write straight into the output layer; no intermediate buffers needed.
    observation_layers = np.zeros((1, size, size))
    observation_layers[0,
                       (row_lo + row_shift):(row_hi + row_shift),
                       (col_lo + col_shift):(col_hi + col_shift)] = self.infoMap[row_lo:row_hi, col_lo:col_hi]
    return observation_layers
def listNextValidActions(self, agent_id, prev_action=0):
    """List the action ids an agent may take next.

    Action ids: 0 = stay, 1 = North (-1, 0), 2 = East (0, 1), 3 = South (1, 0),
    4 = West (0, -1); 5-8 are the diagonals and only considered when
    self.diag is set.

    A move is dropped when it would leave the grid or when it would undo
    ``prev_action``. Action 0 is always included.

    Args:
        agent_id: 1-based agent id.
        prev_action: the agent's previous action id (0 = none / stay).

    Returns:
        numpy array of allowed action ids, starting with 0.
    """
    agent = self.agents[agent_id - 1]
    offsets = [(-1, 0), (0, 1), (1, 0), (0, -1), (-1, -1), (-1, 1), (1, 1), (1, -1)]
    n_moves = 4 + self.diag * 4

    allowed = [0]
    for action in range(1, n_moves + 1):
        d_row, d_col = offsets[action - 1]
        row, col = agent.row + d_row, agent.col + d_col
        inside = (0 <= row < self.shape[0]) and (0 <= col < self.shape[1])
        if inside and prev_action != OPPOSITE_ACTIONS[action]:
            allowed.append(action)
    return np.array(allowed)
def executeAction(self, agentID, action, timeStep):
    """Apply one action to an agent, then update the info map / target status.

    Action ids: 0 = stay, 1 = North (-1, 0), 2 = East (0, 1), 3 = South (1, 0),
    4 = West (0, -1), 5 = (-1, -1), 6 = (-1, 1), 7 = (1, 1), 8 = (1, -1).

    Args:
        agentID: 1-based agent id.
        action: action id in 0..8. Anything else prints a message and
            terminates the process via sys.exit().
        timeStep: forwarded to updateInfoCheckTarget.

    Returns:
        The action that was actually carried out: ``action`` for a valid
        move, 0 when the agent stayed put (action 0, or a move that would
        have left the grid).
    """
    agent = self.agents[agentID - 1]
    origLoc = agent.getLocation()

    if action == 0:
        self.updateInfoCheckTarget(agentID, timeStep, origLoc)
        return 0

    if not ((action >= 1) and (action <= 8)):
        print("INVALID ACTION: {}".format(action))
        sys.exit()

    agent.move(action)
    row, col = agent.getLocation()
    if (row < 0) or (col < 0) or (row >= self.shape[0]) or (col >= self.shape[1]):
        # The move left the grid: undo it so the agent never sits on an
        # off-map cell (this rollback was missing before, although the
        # original comment promised it).
        agent.reverseMove(action)
        self.updateInfoCheckTarget(agentID, timeStep, origLoc)
        return 0

    self.updateInfoCheckTarget(agentID, timeStep, origLoc)
    return action
def updateInfoCheckTarget(self, agentID, timeStep, origLoc):
    """Mark targets under the agent as found and decay the info it has sensed.

    Every target sharing the agent's cell gets updateFound(timeStep), is
    appended to self.found_target and has status set to True. Afterwards the
    sensorSize x sensorSize window around the agent (clipped to the grid) is
    multiplied by 0.05 in self.infoMap.

    Args:
        agentID: 1-based agent id.
        timeStep: passed to target.updateFound.
        origLoc: accepted for interface compatibility; not used.
    """
    agent = self.agents[agentID - 1]
    reach = self.sensorSize // 2

    for target in self.targets:
        on_target = (target.row == agent.row) and (target.col == agent.col)
        if on_target:
            target.updateFound(timeStep)
            self.found_target.append(target)
            target.status = True

    top = max(agent.row - reach, 0)
    bottom = min(agent.row + reach + 1, self.shape[0])
    left = max(agent.col - reach, 0)
    right = min(agent.col + reach + 1, self.shape[1])
    self.infoMap[top:bottom, left:right] *= 0.05
def updateInfoEntireTrajectory(self, agentID):
    """Re-apply the sensing decay along one agent's whole recorded path.

    For every (row, col) in self.trajectories[agentID - 1], the
    sensorSize x sensorSize window around it (clipped to the grid) is
    multiplied by 0.05 in self.infoMap. A cell that appears several times on
    the path is decayed once per appearance. Targets are not touched.

    Args:
        agentID: 1-based agent id.
    """
    reach = self.sensorSize // 2
    for row, col in self.trajectories[agentID - 1]:
        top = max(row - reach, 0)
        bottom = min(row + reach + 1, self.shape[0])
        left = max(col - reach, 0)
        right = min(col + reach + 1, self.shape[1])
        self.infoMap[top:bottom, left:right] *= 0.05
| # Execute one time step within the environment | |
def step(self, agentID, action, timeStep):
    """Make one agent carry out one action.

    Thin wrapper around executeAction; the action that was actually executed
    is not returned to the caller.

    Action ids: 0 = stay, 1 = North (-1, 0), 2 = East (0, 1),
    3 = South (1, 0), 4 = West (0, -1).

    Args:
        agentID: 1-based agent id (asserted to be positive).
        action: action id.
        timeStep: current step counter, forwarded unchanged.
    """
    assert agentID > 0
    self.executeAction(agentID, action, timeStep)
def observe(self, agentID):
    """Return the observation of one agent as a single-element list.

    Args:
        agentID: 1-based agent id (asserted to be positive).

    Returns:
        [extractObservation(agent)].
    """
    assert agentID > 0
    agent = self.agents[agentID - 1]
    return [self.extractObservation(agent)]
def check_finish(self):
    """Tell whether the episode may stop because every target has been found.

    Returns:
        False whenever TERMINATE_ON_TGTS_FOUND is off; otherwise True exactly
        when no target's time_found is NaN (which is also the case when there
        are no targets at all).
    """
    if not TERMINATE_ON_TGTS_FOUND:
        return False
    found_times = [target.time_found for target in self.targets]
    pending = np.isnan(found_times).sum()
    return True if pending == 0 else False
def gradVec(self, observation, agent):
    """Pick the action that climbs the multi-scale information gradient.

    Steps:
      1. Zero out cells below 0.0002 in observation[0] (done in place, so the
         caller's array is modified).
      2. Build four 3x3 summaries of the neighbourhood at different scales
         (raw centre patch and three max-pooled versions) and average them.
      3. Take the gradient at the centre of that 3x3 map and normalise it.
      4. Choose the action whose direction vector has the largest dot product
         with the gradient, then let best_valid_action() veto/replace it.
    If the gradient is zero or NaN a random direction from self.action_vects
    is used instead.

    The slice indices are hard-coded; they yield matching 3x3 patches for an
    11x11 observation.

    Args:
        observation: indexable whose element 0 is the 2-D info window.
        agent: 1-based agent id (forwarded to best_valid_action).

    Returns:
        The selected action id.
    """
    a = observation[0]
    # Make info cells with low value as 0
    a[a < 0.0002] = 0.0

    # Centre 3x3 patch at each scale.
    m_11x11 = np.array(a[4:7, 4:7])
    m_9x9 = np.array(self.pooling(a, (3, 3), stride=(1, 1), method='max', pad=False)[3:6, 3:6])
    m_6x6 = np.array(self.pooling(a, (6, 6), stride=(1, 1), method='max', pad=False)[1:4, 1:4])
    m_3x3 = np.array(self.pooling(a, (5, 5), stride=(3, 3), method='max', pad=False))

    # Equal-weight merge of the scales.
    merged = m_3x3 * 0.25 + m_6x6 * 0.25 + m_9x9 * 0.25 + m_11x11 * 0.25

    # np.gradient returns one array per axis: rows first, then columns.
    adx, ady = np.gradient(merged)
    centre_grad = np.array([adx[1, 1], ady[1, 1]])
    den = np.linalg.norm(centre_grad)
    if (den != 0) and (not np.isnan(den)):
        infovec = centre_grad / den
    else:
        infovec = 0

    # Second check catches infovec == 0 as well as NaNs produced by inf / inf.
    den = np.linalg.norm(infovec)
    if (den != 0) and (not np.isnan(den)):
        direction = infovec / den
    else:
        direction = self.action_vects[np.random.randint(4 + self.diag * 4)]

    if not self.diag:
        action_vec = [[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]]
    else:
        action_vec = [[0., 0.], [-1., 0.], [0., 1.], [1., 0], [0., -1.],
                      [-0.707, -0.707], [-0.707, 0.707], [0.707, 0.707], [0.707, -0.707]]
    actionid = np.argmax([np.dot(direction, vec) for vec in action_vec])
    return self.best_valid_action(actionid, agent, direction)
def best_valid_action(self, actionid, agent, direction):
    """Replace an invalid preferred action by the next best valid one.

    Nothing is checked until at least two steps have been recorded in
    self.actionlist; in that case ``actionid`` is returned unchanged.

    Otherwise, if ``actionid`` is invalid (see action_invalid), the candidate
    actions are ranked by their dot product with ``direction`` and tried from
    the second-best downwards. If none is valid, or the best valid one is
    "stay" (0), the agent backtracks: the opposite of its previous action is
    returned.

    Args:
        actionid: preferred action id.
        agent: 1-based agent id.
        direction: 2-vector (row, col) the agent would like to follow.

    Returns:
        The action id to execute.
    """
    if len(self.actionlist) <= 1 or not self.action_invalid(actionid, agent):
        return actionid

    if not self.diag:
        action_vec = [[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]]
    else:
        action_vec = [[0., 0.], [-1., 0.], [0., 1.], [1., 0], [0., -1.],
                      [-0.707, -0.707], [-0.707, 0.707], [0.707, 0.707], [0.707, -0.707]]

    # Ascending by score, so the best candidate sits at the end.
    ranked = np.array([np.dot(direction, vec) for vec in action_vec]).argsort()

    # Start at the second-best entry; the last entry is the arg-max that the
    # caller normally passes in and that has just been rejected.
    pi = 3 + self.diag * 4
    # Test the index first: ranked[-1] would silently wrap around to the last
    # element and trigger a pointless extra validity probe.
    while pi >= 0 and self.action_invalid(ranked[pi], agent):
        pi -= 1

    if pi == -1 or ranked[pi] == 0:
        return OPPOSITE_ACTIONS[self.actionlist[self.IS_step - 1][agent - 1]]
    return ranked[pi]
def action_invalid(self, action, agent):
    """Tell whether ``action`` is forbidden for the given agent right now.

    An action is invalid when it would undo the agent's previous action
    (going straight back is disabled) or when it is a move (1..8) that would
    take the agent off the grid. The bounds test is done by tentatively
    moving the agent and moving it back, so the agent's position is unchanged
    on return.

    Args:
        action: candidate action id.
        agent: 1-based agent id.

    Returns:
        True if the action must not be taken, False otherwise.
    """
    previous = self.actionlist[self.IS_step - 1][agent - 1]
    if action == OPPOSITE_ACTIONS[previous]:
        return True

    is_move = (action >= 1) and (action <= 8)
    if not is_move:
        return False

    mover = self.agents[agent - 1]
    mover.move(action)
    row, col = mover.getLocation()
    mover.reverseMove(action)  # always restore the agent, valid or not

    if (row < 0) or (col < 0) or (row >= self.shape[0]) or (col >= self.shape[1]):
        return True
    return False
def step_all_parallel(self):
    """Advance every agent by one step.

    All agents first choose an action from their current observation; only
    then are the actions executed, in agent order. The chosen actions are
    appended to self.actionlist as one list, and the new positions are
    recorded for visualisation.
    """
    agent_ids = range(1, self.numAgents + 1)

    # Decide first ...
    chosen = []
    for agent_id in agent_ids:
        window = self.observe(agent_id)[0]
        chosen.append(self.gradVec(window, agent_id))
    self.actionlist.append(chosen)

    # ... then execute.
    for agent_id in range(1, self.numAgents + 1):
        self.step(agent_id, chosen[agent_id - 1], self.IS_step)

    self.record_positions()
def is_scenario(self, max_step=512, episode_number=0):
    """Run one episode of the planner and fill ``self.perf_metrics``.

    If the constructor flagged a failed mask override (self.bad_mask_init),
    every metric is set to None and nothing is simulated.

    Otherwise the loop runs until ``max_step`` steps were taken or
    check_finish() reports completion. Each iteration moves all agents,
    forwards the new positions to the Env, optionally performs the TTA
    update, logs target-found milestones and (if save_image) writes frames.

    Args:
        max_step: upper bound on the number of planner steps.
        episode_number: only used in the final status print.

    Returns:
        None; results are stored in self.perf_metrics.

    NOTE(review): block nesting was reconstructed from a listing whose
    indentation was lost -- confirm against the original file, in particular
    that the info-map refresh belongs inside the TTA branch and the frame
    export inside the loop.
    """
    if self.bad_mask_init:
        # Faulty mask init: report every metric as None.
        for key in (
            'tax',
            'travel_dist',
            'travel_steps',
            'steps_to_first_tgt',
            'steps_to_mid_tgt',
            'steps_to_last_tgt',
            'explored_rate',
            'targets_found',
            'targets_total',
            'kmeans_k',
            'tgts_gt_score',
            'clip_inference_time',
            'tta_time',
            'success_rate',
        ):
            self.perf_metrics[key] = None
        return

    eps_start = time()
    self.IS_step = 0
    self.finished = False

    # Initialize the rendering just once before the loop
    self.init_render()
    self.record_positions()

    # Initial setup: seed the Env masks and the info map from the CLIP heatmaps.
    if LOAD_AVS_BENCH and USE_CLIP_PREDS:
        if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]:
            # Heatmap has to be resampled from CLIP's grid to the planner grid.
            heatmap = self.convert_heatmap_resolution(
                self.clip_seg_tta.heatmap, full_dims=(512, 512),
                new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
            self.env.segmentation_info_mask = np.expand_dims(heatmap.T.flatten(), axis=1)
            unnormalized_heatmap = self.convert_heatmap_resolution(
                self.clip_seg_tta.heatmap_unnormalized, full_dims=(512, 512),
                new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
            self.env.segmentation_info_mask_unnormalized = np.expand_dims(unnormalized_heatmap.T.flatten(), axis=1)
            self.infoMap = copy.deepcopy(heatmap)
            print("Resized heatmap to", NUM_COORDS_WIDTH, "x", NUM_COORDS_HEIGHT)
        else:
            self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1)
            self.env.segmentation_info_mask_unnormalized = np.expand_dims(
                self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1)
            self.infoMap = copy.deepcopy(self.clip_seg_tta.heatmap)

    self.targets_found_on_path.append(self.env.num_new_targets_found)

    robot_id = 0  # 0-based index into the trajectory lists; assume 1 agent for now

    while self.IS_step < max_step and not self.check_finish():
        self.step_all_parallel()
        self.IS_step += 1

        # Render after each step
        if self.save_image:
            self.render(episode_num=self.global_step, step_num=self.IS_step)

        # Update in env
        next_position_list = [self.trajectories_upscaled[i][-1] for i in range(len(self.agents))]
        dist_list = [0 for _ in range(self.numAgents)]
        travel_dist_list = [self.compute_travel_distance(traj) for traj in self.trajectories]
        self.env.multi_robot_step(next_position_list, dist_list, travel_dist_list)
        self.targets_found_on_path.append(self.env.num_new_targets_found)

        # TTA Update via Poisson Test (with KMeans clustering stats)
        robot_traj = self.trajectories[robot_id]
        if LOAD_AVS_BENCH and USE_CLIP_PREDS and EXECUTE_TTA:
            # Flatten (row, col) cells as col * shape[0] + row.
            flat_traj_coords = [col * self.shape[0] + row for row, col in robot_traj]
            robot = SimpleNamespace(
                trajectory_coords=flat_traj_coords,
                targets_found_on_path=self.targets_found_on_path,
            )
            self.poisson_tta_update(robot, self.global_step, self.IS_step)
            # Reload the (possibly updated) mask and re-apply the visit decay.
            self.infoMap = copy.deepcopy(
                self.env.segmentation_info_mask.reshape((self.shape[1], self.shape[0])).T)
            # updateInfoEntireTrajectory takes a 1-based agent id. Passing the
            # 0-based index picked trajectories[-1], i.e. the last agent --
            # only correct while there is exactly one agent.
            self.updateInfoEntireTrajectory(robot_id + 1)

        # Update metrics
        self.log_metrics(step=self.IS_step - 1)

        # Save a frame to generate gif of robot trajectories
        if self.save_image:
            robots_route = [([], [])]  # Assume 1 robot
            for point in self.trajectories_upscaled[robot_id]:
                robots_route[robot_id][0].append(point[0])
                robots_route[robot_id][1].append(point[1])
            if not os.path.exists(GIFS_PATH):
                os.makedirs(GIFS_PATH)
            if LOAD_AVS_BENCH:
                sound_id_override = None if self.clip_seg_tta.sound_ids == [] else self.clip_seg_tta.sound_ids[0]
                self.env.plot_env(
                    self.global_step,
                    GIFS_PATH,
                    self.IS_step - 1,
                    max(travel_dist_list),
                    robots_route,
                    img_path_override=self.clip_seg_tta.img_paths[0],  # Viz 1st
                    sat_path_override=self.clip_seg_tta.imo_path,
                    msk_name_override=self.clip_seg_tta.species_name,
                    sound_id_override=sound_id_override,
                )
            else:
                self.env.plot_env(
                    self.global_step,
                    GIFS_PATH,
                    self.IS_step - 1,
                    max(travel_dist_list),
                    robots_route,
                )

    # Log metrics
    if LOAD_AVS_BENCH:
        tax = Path(self.clip_seg_tta.gt_mask_name).stem
        # Drop the first '_'-separated token of the file stem.
        self.perf_metrics['tax'] = " ".join(tax.split("_")[1:])
    else:
        self.perf_metrics['tax'] = None

    travel_distances = [self.compute_travel_distance(traj) for traj in self.trajectories]
    self.perf_metrics['travel_dist'] = max(travel_distances)
    self.perf_metrics['travel_steps'] = self.IS_step
    self.perf_metrics['steps_to_first_tgt'] = self.steps_to_first_tgt
    self.perf_metrics['steps_to_mid_tgt'] = self.steps_to_mid_tgt
    self.perf_metrics['steps_to_last_tgt'] = self.steps_to_last_tgt
    self.perf_metrics['explored_rate'] = self.env.explored_rate
    self.perf_metrics['targets_found'] = self.env.targets_found_rate
    self.perf_metrics['targets_total'] = len(self.env.target_positions)
    if USE_CLIP_PREDS:
        self.perf_metrics['kmeans_k'] = self.kmeans_clusterer.final_k
        self.perf_metrics['tgts_gt_score'] = self.clip_seg_tta.tgts_gt_score
        self.perf_metrics['clip_inference_time'] = self.clip_seg_tta.clip_inference_time
        self.perf_metrics['tta_time'] = self.clip_seg_tta.tta_time
    else:
        self.perf_metrics['kmeans_k'] = None
        self.perf_metrics['tgts_gt_score'] = None
        self.perf_metrics['clip_inference_time'] = None
        self.perf_metrics['tta_time'] = None

    if FORCE_LOGGING_DONE_TGTS_FOUND and self.env.targets_found_rate == 1.0:
        self.perf_metrics['success_rate'] = True
    else:
        self.perf_metrics['success_rate'] = self.env.check_done()[0]

    # save gif
    if self.save_image:
        self.make_gif(GIFS_PATH, self.global_step)

    print(YELLOW, f"[Eps {episode_number} Completed] Time Taken: {time()-eps_start:.2f}s, Steps: {self.IS_step}", NC)
def asStride(self, arr, sub_shape, stride):
    """Return a strided view of all ``sub_shape`` windows of ``arr``.

    See also skimage.util.shape.view_as_windows().

    Args:
        arr: ndarray with at least two dimensions.
        sub_shape: (window_rows, window_cols).
        stride: (row_step, col_step) between neighbouring windows.

    Returns:
        A view of shape (n_win_rows, n_win_cols, window_rows, window_cols)
        followed by any trailing dimensions of ``arr``. No data is copied.
    """
    row_bytes, col_bytes = arr.strides[:2]
    n_rows, n_cols = arr.shape[:2]
    win_rows, win_cols = sub_shape

    n_win_rows = 1 + (n_rows - win_rows) // stride[0]
    n_win_cols = 1 + (n_cols - win_cols) // stride[1]

    window_shape = (n_win_rows, n_win_cols, win_rows, win_cols) + arr.shape[2:]
    window_strides = (stride[0] * row_bytes, stride[1] * col_bytes, row_bytes, col_bytes) + arr.strides[2:]
    return np.lib.stride_tricks.as_strided(arr, window_shape, strides=window_strides)
def pooling(self, mat, ksize, stride=None, method='max', pad=False):
    """Pool a 2-D (or 3-D) array with a possibly overlapping window.

    Args:
        mat: ndarray, input array to pool.
        ksize: tuple of 2, kernel size in (ky, kx).
        stride: tuple of 2 or None, stride of the pooling window.
            None means same as ``ksize`` (non-overlapping pooling).
        method: 'max' for max-pooling; any other value gives mean-pooling.
        pad: if False the output has size (n - f) // s + 1 per axis
            (n = input size, f = kernel size, s = stride); if True the input
            is padded with NaN and the output has size ceil(n / s).

    Returns:
        The pooled array. NaNs are ignored by both reductions.
    """
    n_rows, n_cols = mat.shape[:2]
    k_rows, k_cols = ksize
    if stride is None:
        stride = (k_rows, k_cols)
    step_rows, step_cols = stride

    if not pad:
        # Crop to the largest region fully covered by whole windows.
        used_rows = (n_rows - k_rows) // step_rows * step_rows + k_rows
        used_cols = (n_cols - k_cols) // step_cols * step_cols + k_cols
        work = mat[:used_rows, :used_cols, ...]
    else:
        out_rows = int(np.ceil(n_rows / float(step_rows)))
        out_cols = int(np.ceil(n_cols / float(step_cols)))
        padded_shape = ((out_rows - 1) * step_rows + k_rows,
                        (out_cols - 1) * step_cols + k_cols) + mat.shape[2:]
        work = np.full(padded_shape, np.nan)
        work[:n_rows, :n_cols, ...] = mat

    windows = self.asStride(work, ksize, stride)
    reduce_windows = np.nanmax if method == 'max' else np.nanmean
    return reduce_windows(windows, axis=(2, 3))
def compute_travel_distance(self, trajectory):
    """Return the summed Euclidean length of a path.

    Args:
        trajectory: sequence of positions (e.g. (row, col) tuples).

    Returns:
        Sum of the distances between consecutive positions; 0.0 for a path
        with fewer than two points.
    """
    total = 0.0
    for start, end in zip(trajectory, trajectory[1:]):
        hop = np.array(end) - np.array(start)
        total += np.linalg.norm(hop)
    return total
| ################################################################################ | |
| # SPPP Related Fns | |
| ################################################################################ | |
def log_metrics(self, step):
    """Record at which step the first / middle / last target was found.

    Reads env.num_targets_found and len(env.target_positions) and stores
    ``step + 1`` in self.steps_to_first_tgt, self.steps_to_mid_tgt and
    self.steps_to_last_tgt the first time the respective milestone is
    reached (at least 1, at least total // 2, and all targets found).

    Thresholds are checked with >= rather than ==, so a milestone is still
    recorded when the found-count jumps over it within one step, and nothing
    is recorded before at least one target has been found (the previous
    equality test marked the "mid" milestone at zero finds when
    total // 2 == 0).

    Args:
        step: 0-based index of the step that has just been executed.
    """
    found = self.env.num_targets_found
    total = len(self.env.target_positions)
    if found < 1:
        return

    steps_taken = step + 1
    if self.steps_to_first_tgt is None:
        self.steps_to_first_tgt = steps_taken
    if self.steps_to_mid_tgt is None and found >= total // 2:
        self.steps_to_mid_tgt = steps_taken
    if self.steps_to_last_tgt is None and found >= total:
        self.steps_to_last_tgt = steps_taken
def transpose_flat_idx(self, idx, H=NUM_COORDS_HEIGHT, W=NUM_COORDS_WIDTH):
    """Map a flat row-major index of an H x W grid to the transposed W x H grid.

    The cell (row, col) of the H x W grid becomes (col, row) in the W x H
    grid; the result is that cell's flat row-major index, so it stays 1-D.

    Args:
        idx: flat index, 0 <= idx < H * W (asserted).
        H: number of rows of the original grid.
        W: number of columns of the original grid.

    Returns:
        The flat index in the transposed grid.
    """
    # Safety check to catch out-of-range indices
    assert 0 <= idx < H * W, f"idx {idx} out of bounds for shape ({H}, {W})"

    row = idx // W
    col = idx % W
    # Row-major position of (col, row) in a grid with H columns.
    return col * H + row
def poisson_tta_update(self, robot, episode, step):
    """Refresh cluster statistics and, when due, run a TTA update of the heatmap.

    The robot's visited cells are first converted to CLIP grid indices and
    passed to the clusterer's compute_region_statistics(). A TTA update is
    then executed if the latest entry of robot.targets_found_on_path is
    truthy or STEPS_PER_TTA steps have passed since the last update. After an
    update the Env's segmentation masks are rebuilt from the new heatmaps and
    the step counter is reset.

    Args:
        robot: object with ``trajectory_coords`` (flat cell indices) and
            ``targets_found_on_path``.
        episode: forwarded to compute_region_statistics as episode_num.
        step: forwarded as step_num; also scales the learning rate.

    NOTE(review): block nesting was reconstructed from a listing whose
    indentation was lost -- confirm against the original file.
    """
    # Scale indices back to CLIP_GRIDS_DIMS to be compatible with CLIP patch size.
    if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]:
        filt_traj_coords, filt_targets_found_on_path = self.scale_trajectory(
            robot.trajectory_coords,
            self.env.target_positions,
            old_dims=(NUM_COORDS_HEIGHT, NUM_COORDS_WIDTH),
            full_dims=(512, 512),
            new_dims=(CLIP_GRIDS_DIMS[0], CLIP_GRIDS_DIMS[1]),
        )
    else:
        filt_traj_coords = [self.transpose_flat_idx(flat_idx) for flat_idx in robot.trajectory_coords]
        filt_targets_found_on_path = robot.targets_found_on_path

    region_stats_dict = self.kmeans_clusterer.compute_region_statistics(
        self.kmeans_sat_embeds_clusters,
        self.clip_seg_tta.heatmap_unnormalized,
        filt_traj_coords,
        episode_num=episode,
        step_num=step,
    )

    # Prep & execute TTA
    self.step_since_tta += 1
    tta_due = robot.targets_found_on_path[-1] or self.step_since_tta % STEPS_PER_TTA == 0
    if not tta_due:
        return

    heatmap_shape = self.clip_seg_tta.heatmap.shape
    num_cells = heatmap_shape[0] * heatmap_shape[1]

    pos_sample_weight_scale = []
    neg_sample_weight_scale = []
    for sample_loc in filt_traj_coords:
        label = self.kmeans_clusterer.get_label_id(self.kmeans_sat_embeds_clusters, sample_loc)
        stats = region_stats_dict[label]
        num_patches = stats['num_patches']
        patches_visited = stats['patches_visited']
        # Exponent like focal loss: wait for more samples before confidently decreasing.
        pos_sample_weight_scale.append(4.0)
        neg_sample_weight_scale.append(min(1.0, (patches_visited / (3 * num_patches)) ** GAMMA_EXPONENT))

    # Adaptive LR: grows linearly with the step count relative to the number of cells.
    adaptive_lr = MIN_LR + (MAX_LR - MIN_LR) * (step / num_cells)

    # TTA Update
    self.clip_seg_tta.execute_tta(
        filt_traj_coords,
        filt_targets_found_on_path,
        tta_steps=NUM_TTA_STEPS,
        lr=adaptive_lr,
        pos_sample_weight=pos_sample_weight_scale,
        neg_sample_weight=neg_sample_weight_scale,
        reset_weights=RESET_WEIGHTS,
    )

    # Push the refreshed heatmaps into the Env as flattened column vectors.
    if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]:
        heatmap = self.convert_heatmap_resolution(
            self.clip_seg_tta.heatmap, full_dims=(512, 512),
            new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
        self.env.segmentation_info_mask = np.expand_dims(heatmap.T.flatten(), axis=1)
        unnormalized_heatmap = self.convert_heatmap_resolution(
            self.clip_seg_tta.heatmap_unnormalized, full_dims=(512, 512),
            new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
        self.env.segmentation_info_mask_unnormalized = np.expand_dims(unnormalized_heatmap.T.flatten(), axis=1)
        print("~Resized heatmap to", NUM_COORDS_WIDTH, "x", NUM_COORDS_HEIGHT)
    else:
        self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1)
        self.env.segmentation_info_mask_unnormalized = np.expand_dims(
            self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1)

    self.step_since_tta = 0
def convert_heatmap_resolution(self, heatmap, full_dims=(512, 512), new_dims=(24, 24)):
    """Resample a heatmap onto the planner grid.

    The heatmap is first resized to ``full_dims`` (bilinear, reflect padding,
    anti-aliased) and then sampled at the positions stored in
    env.graph_generator.grid_coords, whose last axis is read as
    (column, row) indices into the resized image.

    Args:
        heatmap: 2-D array to resample.
        full_dims: shape of the intermediate full-resolution image.
        new_dims: target grid size; the result is reshaped to
            (new_dims[1], new_dims[0]).

    Returns:
        The resampled heatmap.
    """
    # order=1 -> bilinear
    upsampled = resize(heatmap, full_dims, order=1, mode='reflect', anti_aliasing=True)

    grid_coords = self.env.graph_generator.grid_coords
    sample_rows = grid_coords[..., 1]
    sample_cols = grid_coords[..., 0]

    sampled = upsampled[sample_rows, sample_cols]
    return sampled.reshape(new_dims[1], new_dims[0])
def convert_labelmap_resolution(self, labelmap, full_dims=(512, 512), new_dims=(24, 24)):
    """Resample an integer label map onto the planner grid.

    1) Upsample to ``full_dims`` with nearest-neighbour interpolation so that
       label values are never blended.
    2) Sample the upsampled map at env.graph_generator.grid_coords, whose
       last axis is read as (column, row) and cast to int.

    Args:
        labelmap: 2-D label array; the result keeps its dtype.
        full_dims: shape of the intermediate full-resolution map.
        new_dims: target grid size; the result is reshaped to
            (new_dims[0], new_dims[1]).

    Returns:
        The resampled label map.
    """
    upsampled = resize(
        labelmap,
        full_dims,
        order=0,              # nearest-neighbour
        mode='edge',          # padding mode
        preserve_range=True,  # keep the label values as they are
        anti_aliasing=False,  # smoothing would invent labels
    )
    upsampled = upsampled.astype(labelmap.dtype)

    grid_coords = self.env.graph_generator.grid_coords
    sample_rows = grid_coords[..., 1].astype(int)
    sample_cols = grid_coords[..., 0].astype(int)

    sampled = upsampled[sample_rows, sample_cols]
    return sampled.reshape(new_dims[0], new_dims[1])
def scale_trajectory(self,
                     flat_indices,
                     targets,
                     old_dims=(17, 17),
                     full_dims=(512, 512),
                     new_dims=(24, 24)):
    """
    Project a node trajectory and a set of pixel targets onto a new grid.

    Args:
        flat_indices: iterable of indices into
            self.env.graph_generator.node_coords; indices outside that
            list are skipped
        targets: iterable of (x_pix, y_pix) pixel positions in the
            full-resolution image (x first, then y)
        old_dims: (old_H, old_W)
        full_dims: (full_H, full_W)
        new_dims: (new_H, new_W)
    Returns:
        new_flat_traj: list of unique flattened indices in new_H x new_W,
            in order of first appearance
        counts: list of ints, same length as new_flat_traj, giving the
            number of targets binned into each of those cells
    """
    old_H, old_W = old_dims
    full_H, full_W = full_dims
    new_H, new_W = new_dims

    cell_h_new = full_H / new_H
    cell_w_new = full_W / new_W

    def to_new_row(y_pix):
        # Clamp on both ends: a negative pixel would otherwise become a
        # negative index and silently wrap to the far side of the grid
        return max(0, min(int(y_pix / cell_h_new), new_H - 1))

    def to_new_col(x_pix):
        return max(0, min(int(x_pix / cell_w_new), new_W - 1))

    # 1) bin targets into the new grid (flattened row-major)
    target_counts = [0] * (new_H * new_W)
    for x_pix, y_pix in targets:
        target_counts[to_new_row(y_pix) * new_W + to_new_col(x_pix)] += 1

    # 2) walk the trajectory and project each old cell's entire pixel
    #    footprint onto the new grid
    cell_h_full = full_H / old_H
    cell_w_full = full_W / old_W
    graph = self.env.graph_generator
    node_coords = graph.node_coords
    num_nodes = len(node_coords)

    seen = set()
    new_flat_traj = []
    for node_idx in flat_indices:
        if node_idx < 0 or node_idx >= num_nodes:
            continue
        coord_xy = node_coords[node_idx]
        try:
            row_old, col_old = graph.find_index_from_grid_coords_2d(coord_xy)
        except Exception:
            # Best effort: a node that cannot be located on the grid is skipped
            continue

        # Bounding box of the old cell in full-resolution pixel space
        y0 = row_old * cell_h_full
        y1 = (row_old + 1) * cell_h_full
        x0 = col_old * cell_w_full
        x1 = (col_old + 1) * cell_w_full

        # New-grid rows and cols overlapped by that box (inclusive ranges)
        i_start, i_end = to_new_row(y0), to_new_row(y1 - 1)
        j_start, j_end = to_new_col(x0), to_new_col(x1 - 1)

        for ii in range(i_start, i_end + 1):
            for jj in range(j_start, j_end + 1):
                f_new = ii * new_W + jj
                if f_new not in seen:
                    seen.add(f_new)
                    new_flat_traj.append(f_new)

    # 3) annotate each trajectory cell with its target count
    counts = [target_counts[f] for f in new_flat_traj]
    return new_flat_traj, counts
| ################################################################################ | |
def make_gif(self, path, n):
    """
    Generate gifs from the saved frame images and delete the frames.

    Writes '<path>/<n>_target_rate_<rate>.gif' from self.env.frame_files,
    '<path>/<n>_kmeans_stats.gif' from the KMeans clusterer frames (only
    when LOAD_AVS_BENCH and USE_CLIP_PREDS are both set), and
    '<path>/<n>_IS.gif' from self.IS_frame_files.
    """

    def write_gif(gif_path, frame_files, done_msg):
        # Assemble the frames into a single gif at 5 fps
        with imageio.get_writer(gif_path, mode='I', fps=5) as writer:
            for frame in frame_files:
                writer.append_data(imageio.imread(frame))
        print(done_msg)
        # Remove files (every frame except the last one)
        for filename in frame_files[:-1]:
            os.remove(filename)

    write_gif('{}/{}_target_rate_{:.2f}.gif'.format(path, n, self.env.targets_found_rate),
              self.env.frame_files,
              'gif complete\n')

    # For KMeans gif
    if LOAD_AVS_BENCH and USE_CLIP_PREDS:
        write_gif('{}/{}_kmeans_stats.gif'.format(path, n),
                  self.kmeans_clusterer.kmeans_frame_files,
                  'Kmeans Clusterer gif complete\n')

    # IS gif
    # NOTE(review): written unconditionally here -- confirm it is not meant
    # to sit under the LOAD_AVS_BENCH / USE_CLIP_PREDS check above.
    write_gif('{}/{}_IS.gif'.format(path, n),
              self.IS_frame_files,
              'IS gif complete\n')
| ################################################################################ | |
class Agent:
    """Agent identified by `ID` that keeps track of its own (row, col) cell."""

    def __init__(self, ID, infoMap=None, uncertaintyMap=None, shape=None, row=0, col=0, sensorSize=9, numAgents=8):
        # infoMap, uncertaintyMap and shape are accepted but not stored
        self.ID = ID
        self.row, self.col = row, col
        self.numAgents = numAgents
        self.sensorSize = sensorSize

    def setLocation(self, row, col):
        """Place the agent on cell (row, col)."""
        self.row, self.col = row, col

    def getLocation(self):
        """Return the current cell as a [row, col] list."""
        return [self.row, self.col]

    def move(self, action):
        """
        Shift the agent by one cell according to `action`.

        0: no movement (returns 0)
        1: North (-1, 0)      2: East (0, 1)
        3: South (1, 0)       4: West (0, -1)
        5: up-left (-1, -1)   6: up-right (-1, 1)
        7: down-right (1, 1)  8: down-left (1, -1)

        Any other value leaves the position untouched. No boundary check
        is done here.
        """
        if action == 0:
            return 0

        # Row component
        if action in (1, 5, 6):
            self.row -= 1
        elif action in (3, 7, 8):
            self.row += 1

        # Column component
        if action in (2, 6, 7):
            self.col += 1
        elif action in (4, 5, 8):
            self.col -= 1

    def reverseMove(self, action):
        """
        Undo a previous `move(action)`.

        Returns 0 for action 0. An action outside 0..8 prints a message and
        terminates the program.
        """
        if action == 0:
            return 0

        if action not in (1, 2, 3, 4, 5, 6, 7, 8):
            print("agent can only move NESW/1234")
            sys.exit()

        # Row component (opposite sign to move)
        if action in (1, 5, 6):
            self.row += 1
        elif action in (3, 7, 8):
            self.row -= 1

        # Column component (opposite sign to move)
        if action in (2, 6, 7):
            self.col -= 1
        elif action in (4, 5, 8):
            self.col += 1
class Target:
    """Target at cell (row, col) with the time steps at which it was found and visited."""

    def __init__(self, row, col, ID, time_found=np.nan):
        self.ID = ID
        self.row, self.col = row, col
        self.status = None
        # Both timestamps start from the same value; NaN marks "not set yet"
        self.time_found = time_found
        self.time_visited = time_found

    def getLocation(self):
        """Return the target cell as a (row, col) tuple."""
        return (self.row, self.col)

    def updateFound(self, timeStep):
        """Record `timeStep` as the found time, unless one is already set."""
        if not np.isnan(self.time_found):
            return
        self.time_found = timeStep

    def updateVisited(self, timeStep):
        """Record `timeStep` as the visited time, unless one is already set."""
        if not np.isnan(self.time_visited):
            return
        self.time_visited = timeStep
if __name__ == "__main__":
    # Build the search environment first; its data seeds the ISEnv state
    search_env = Env(map_index=1, k_size=K_SIZE, n_agent=NUM_ROBOTS, plot=SAVE_GIFS)

    # ISEnv state is passed as [info map, agent locations, target locations]
    initial_state = [
        search_env.segmentation_info_mask,
        search_env.start_positions,
        [[312, 123], [123, 312], [312, 312], [123, 123]],  # hard-coded target locations
    ]
    grid_shape = (NUM_COORDS_HEIGHT, NUM_COORDS_WIDTH)

    env = ISEnv(state=initial_state, shape=grid_shape)
    env.is_scenario(NUM_EPS_STEPS)
    print()