# Source: search-tta-demo / planner / test_parameter.py
# Author: derektan
# Commit: e330ebf ("Initial Commit")
############################################################################################
# Name: test_parameter.py
#
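# NOTE: Set all evaluation hyper-params here. Params read via getenv() can also
#       be overridden with environment variables of the same name.
# Simple How-To Guide:
# 1. CLIP TTA: USE_CLIP_PREDS=True, EXECUTE_TTA=True
# 2. CLIP (No TTA): USE_CLIP_PREDS=True, EXECUTE_TTA=False
# 3. Custom masks (e.g. LISA): USE_CLIP_PREDS=False, EXECUTE_TTA=False, and point
#    OVERRIDE_MASK_DIR at the folder containing the custom masks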
############################################################################################
import os
import sys
# Global run-mode flag published under the key 'TRAINING' in sys.modules
# (presumably read by other modules of the planner -- confirm at the readers).
# NOTE(review): sys.modules normally maps names to module objects; storing a
# plain bool here can confuse code that iterates sys.modules expecting modules.
# Left unchanged because readers elsewhere depend on this exact key and value.
sys.modules['TRAINING'] = False # False = Inference Testing
###############################################################
# Overload Params
###############################################################
OPT_VARS = {}  # Every value resolved through getenv(), keyed by env-var name (for logging)


def getenv(var_name, default=None, cast_type=str):
    """Resolve a parameter, letting an environment variable override the default.

    Args:
        var_name: Name of the environment variable to read.
        default: Value used when the variable is unset, or when its value
            cannot be converted with ``cast_type``.
        cast_type: Callable converting the raw string (e.g. ``int``, ``float``,
            ``str``). ``bool`` is special-cased: the result is True only when
            the value is "true", "1" or "yes" (case-insensitive, surrounding
            whitespace ignored); any other string gives False.

    Returns:
        The resolved value, which is also recorded in ``OPT_VARS[var_name]``.
    """
    value = os.environ.get(var_name)
    if value is None:
        result = default
    elif cast_type is bool:
        # bool("false") would be True, so parse the text explicitly.
        result = value.strip().lower() in ("true", "1", "yes")
    else:
        try:
            result = cast_type(value)
        except (ValueError, TypeError):
            # Best-effort: keep running on the default, but surface the bad
            # override instead of silently ignoring it.
            type_name = getattr(cast_type, "__name__", repr(cast_type))
            print(
                f"[getenv] Cannot convert {var_name}={value!r} to {type_name}; "
                f"using default {default!r}",
                file=sys.stderr,
            )
            result = default
    OPT_VARS[var_name] = result  # Log the resolved value
    return result
###############################################################
# General
###############################################################
# --- GENERAL --- #
USE_GPU = False
NUM_GPU = getenv("NUM_GPU", default=1, cast_type=int)  # the number of GPUs
NUM_META_AGENT = getenv("NUM_META_AGENT", default=2, cast_type=int)  # the number of concurrent processes
NUM_EPS_STEPS = getenv("NUM_EPS_STEPS", default=400, cast_type=int)
FIX_START_POSITION = getenv("FIX_START_POSITION", default=True, cast_type=bool)  # Whether to fix the starting position of the robots (middle index)
NUM_ROBOTS = 1  # Only allow for 1 robot
NUM_COORDS_WIDTH = 24  # How many node coords across width?
NUM_COORDS_HEIGHT = 24  # How many node coords across height?
CLIP_GRIDS_DIMS = [24, 24]  # [16, 16] if 'openai/clip-vit-large-patch14-336'
SENSOR_RANGE = 80  # Only applicable to 'circle' sensor model
SENSOR_MODEL = "rectangular"  # "rectangular", "circle" (NOTE: no collision check for rectangular)
TERMINATE_ON_TGTS_FOUND = True  # Whether to terminate episode when all targets found
FORCE_LOGGING_DONE_TGTS_FOUND = True  # Whether to force csv logging when all targets found
# --- Planner Params --- #
POLICY = getenv("POLICY", default="RL", cast_type=str)
NUM_TEST = 800  # Overridden if LOAD_AVS_BENCH
NUM_RUN = 1
MODEL_NAME = "avs_rl_policy.pth"
INPUT_DIM = 4
EMBEDDING_DIM = 128
K_SIZE = 8
# --- Folders & Visualizations --- #
GRIDMAP_SET_DIR = "maps/gpt4o/envs_val"
MASK_SET_DIR = "maps/example/masks_val"  # Overridden if LOAD_AVS_BENCH
TARGETS_SET_DIR = ""
# TARGETS_SET_DIR = "maps/example/gt_masks_val_with_tgts"  # Overridden if LOAD_AVS_BENCH
OVERRIDE_MASK_DIR = getenv("OVERRIDE_MASK_DIR", default="", cast_type=str)  # Override initial score mask from CLIP
SAVE_GIFS = getenv("SAVE_GIFS", default=True, cast_type=bool)  # do you want to save GIFs
FOLDER_NAME = 'avs_search'
MODEL_PATH = 'inference/model'
GIFS_PATH = f'inference/test_results/gifs/{FOLDER_NAME}'
LOG_PATH = f'inference/test_results/log/{FOLDER_NAME}'
LOG_TEMPLATE_XLSX = 'inference/template.xlsx'
CSV_EXPT_NAME = getenv("CSV_EXPT_NAME", default="data", cast_type=str)
VIZ_GRAPH_EDGES = False  # do you want to visualize the graph edges
#######################################################################
# AVS Params
#######################################################################
# General PARAMS
USE_CLIP_PREDS = getenv("USE_CLIP_PREDS", default=True, cast_type=bool)  # If False, use custom masks from OVERRIDE_MASK_DIR
QUERY_TAX = getenv("QUERY_TAX", default="", cast_type=str)  # "" = Test all taxa (can accept taxonomy substrings)
EXECUTE_TTA = getenv("EXECUTE_TTA", default=True, cast_type=bool)  # Whether to execute TTA mask updates
QUERY_MODALITY = getenv("QUERY_MODALITY", default="image", cast_type=str)  # "image", "text", "sound"
STEPS_PER_TTA = 20  # no. steps before each TTA series
NUM_TTA_STEPS = 1  # no. of TTA steps during each series
RESET_WEIGHTS = True
MIN_LR = 1e-6
MAX_LR = 1e-5
GAMMA_EXPONENT = 2
# Paths related to AVS (TRAIN w/ TARGETS)
LOAD_AVS_BENCH = True  # Whether to init AVS datasets
# Root directory of the AVS-Bench dataset. The default is the original
# machine-specific location; set the AVS_BENCH_ROOT environment variable to
# relocate the dataset without editing each path below.
AVS_BENCH_ROOT = getenv("AVS_BENCH_ROOT", default="/mnt/hdd/avs_bench_ds", cast_type=str)
AVS_IMG_DIR = os.path.join(AVS_BENCH_ROOT, 'inat21')
AVS_IMO_DIR = os.path.join(AVS_BENCH_ROOT, 'sat_jpg', 'train_512px')
AVS_INAT_JSON_PATH = os.path.join(AVS_BENCH_ROOT, 'inat21', 'train.json')
AVS_SOUND_DIR = os.path.join(AVS_BENCH_ROOT, 'sound_mp3', 'test')
AVS_GAUSSIAN_BLUR_KERNEL = (5, 5)
AVS_SAT_TO_IMG_IDS_PATH = getenv("AVS_SAT_TO_IMG_IDS_PATH", default="search_tri_modal|val_in_domain", cast_type=str)
AVS_LOAD_PRETRAINED_HF_CHECKPOINT = getenv("AVS_LOAD_PRETRAINED_HF_CHECKPOINT", default=True, cast_type=bool)  # If False, load locally using CHECKPOINT_PATHs
AVS_SAT_CHECKPOINT_PATH = getenv("AVS_SAT_CHECKPOINT_PATH", default="", cast_type=str)
AVS_SOUND_CHECKPOINT_PATH = getenv("AVS_SOUND_CHECKPOINT_PATH", default="", cast_type=str)
#######################################################################
# UTILS
#######################################################################
# ANSI terminal escape sequences for colored / bold printing.
RED, GREEN, YELLOW = (
    '\033[1;31m',  # bold red
    '\033[1;32m',  # bold green
    '\033[1;93m',  # bold bright yellow
)
NC_BOLD = '\033[1m'  # bold, no color
NC = '\033[0m'  # reset to no color