Spaces:
Running
on
Zero
Running
on
Zero
############################################################################################
# Name: test_parameter.py
#
# NOTE: Change all your hyper-params here for eval
# Simple How-To Guide:
# 1. CLIP TTA: USE_CLIP_PREDS=True, EXECUTE_TTA=True
# 2. CLIP (No TTA): USE_CLIP_PREDS=True, EXECUTE_TTA=False
# 3. Custom masks (e.g. LISA): USE_CLIP_PREDS=False, EXECUTE_TTA=False
############################################################################################
import os
import sys

# Process-wide mode flag stored under the 'TRAINING' key of sys.modules
# (False = inference testing).
# NOTE(review): sys.modules normally maps module names to module objects; a
# bool is stashed here instead, presumably so other modules can read the flag
# without importing this file -- confirm against the readers before changing.
sys.modules['TRAINING'] = False
###############################################################
# Overload Params
###############################################################
# Effective value of every option resolved through getenv(), keyed by the
# environment-variable name.
OPT_VARS = {}


def getenv(var_name, default=None, cast_type=str):
    """Read an option from the environment, convert it, and record it.

    Args:
        var_name: Name of the environment variable to look up.
        default: Value used when the variable is unset, or when
            ``cast_type`` rejects the raw string.
        cast_type: Callable applied to the raw string. ``bool`` is
            special-cased: "true", "1" and "yes" (case-insensitive,
            surrounding whitespace ignored) give True; any other value
            gives False.

    Returns:
        The converted value, or ``default``.

    The returned value is also stored in ``OPT_VARS[var_name]``.
    """
    raw = os.environ.get(var_name)
    if raw is None:
        result = default
    elif cast_type is bool:
        # bool("false") would be True, so booleans are parsed by keyword.
        result = raw.strip().lower() in ("true", "1", "yes")
    else:
        try:
            result = cast_type(raw)
        except (ValueError, TypeError):
            # Best effort: an unparsable override falls back to the default.
            result = default
    OPT_VARS[var_name] = result  # Log the result
    return result
###############################################################
# General
###############################################################
# --- GENERAL --- #
USE_GPU = False
# The number of GPUs
NUM_GPU = getenv("NUM_GPU", default=1, cast_type=int)
# The number of concurrent processes
NUM_META_AGENT = getenv("NUM_META_AGENT", default=2, cast_type=int)
NUM_EPS_STEPS = getenv("NUM_EPS_STEPS", default=400, cast_type=int)
# Whether to fix the starting position of the robots (middle index)
FIX_START_POSITION = getenv("FIX_START_POSITION", default=True, cast_type=bool)
NUM_ROBOTS = 1  # Only allow for 1 robot
NUM_COORDS_WIDTH = 24  # How many node coords across width?
NUM_COORDS_HEIGHT = 24  # How many node coords across height?
CLIP_GRIDS_DIMS = [24, 24]  # [16,16] if 'openai/clip-vit-large-patch14-336'
SENSOR_RANGE = 80  # Only applicable to 'circle' sensor model
# "rectangular", "circle" (NOTE: no collision check for rectangular)
SENSOR_MODEL = "rectangular"
# Whether to terminate episode when all targets found
TERMINATE_ON_TGTS_FOUND = True
# Whether to force csv logging when all targets found
FORCE_LOGGING_DONE_TGTS_FOUND = True
# --- Planner Params --- #
POLICY = getenv("POLICY", default="RL", cast_type=str)
NUM_TEST = 800  # Overridden if LOAD_AVS_BENCH
NUM_RUN = 1
MODEL_NAME = "avs_rl_policy.pth"
INPUT_DIM = 4
EMBEDDING_DIM = 128
K_SIZE = 8
# --- Folders & Visualizations --- #
GRIDMAP_SET_DIR = "maps/gpt4o/envs_val"
MASK_SET_DIR = "maps/example/masks_val"  # Overridden if LOAD_AVS_BENCH
TARGETS_SET_DIR = ""
# TARGETS_SET_DIR = "maps/example/gt_masks_val_with_tgts" # Overridden if LOAD_AVS_BENCH
# Override initial score mask from CLIP
OVERRIDE_MASK_DIR = getenv("OVERRIDE_MASK_DIR", default="", cast_type=str)
# Do you want to save GIFs
SAVE_GIFS = getenv("SAVE_GIFS", default=True, cast_type=bool)
FOLDER_NAME = 'avs_search'
MODEL_PATH = 'inference/model'
GIFS_PATH = f'inference/test_results/gifs/{FOLDER_NAME}'
LOG_PATH = f'inference/test_results/log/{FOLDER_NAME}'
LOG_TEMPLATE_XLSX = 'inference/template.xlsx'
CSV_EXPT_NAME = getenv("CSV_EXPT_NAME", default="data", cast_type=str)
VIZ_GRAPH_EDGES = False  # Do you want to visualize the graph edges
#######################################################################
# AVS Params
#######################################################################
# General PARAMS
# If false, use custom masks from OVERRIDE_MASK_DIR
USE_CLIP_PREDS = getenv("USE_CLIP_PREDS", default=True, cast_type=bool)
# "" = Test all tax (can accept taxonomy substrings)
QUERY_TAX = getenv("QUERY_TAX", default="", cast_type=str)
# Whether to execute TTA mask updates
EXECUTE_TTA = getenv("EXECUTE_TTA", default=True, cast_type=bool)
# "image", "text", "sound"
QUERY_MODALITY = getenv("QUERY_MODALITY", default="image", cast_type=str)
STEPS_PER_TTA = 20  # no. steps before each TTA series
NUM_TTA_STEPS = 1  # no. of TTA steps during each series
RESET_WEIGHTS = True
MIN_LR = 1e-6
MAX_LR = 1e-5
GAMMA_EXPONENT = 2
# Paths related to AVS (TRAIN w/ TARGETS)
LOAD_AVS_BENCH = True  # Whether to init AVS datasets
AVS_IMG_DIR = '/mnt/hdd/avs_bench_ds/inat21'
AVS_IMO_DIR = '/mnt/hdd/avs_bench_ds/sat_jpg/train_512px'
AVS_INAT_JSON_PATH = '/mnt/hdd/avs_bench_ds/inat21/train.json'
AVS_SOUND_DIR = '/mnt/hdd/avs_bench_ds/sound_mp3/test'
AVS_GAUSSIAN_BLUR_KERNEL = (5, 5)
AVS_SAT_TO_IMG_IDS_PATH = getenv(
    "AVS_SAT_TO_IMG_IDS_PATH",
    default="search_tri_modal|val_in_domain",
    cast_type=str,
)
# If false, load locally using CHECKPOINT_PATHs
AVS_LOAD_PRETRAINED_HF_CHECKPOINT = getenv(
    "AVS_LOAD_PRETRAINED_HF_CHECKPOINT",
    default=True,
    cast_type=bool,
)
AVS_SAT_CHECKPOINT_PATH = getenv("AVS_SAT_CHECKPOINT_PATH", default="", cast_type=str)
AVS_SOUND_CHECKPOINT_PATH = getenv("AVS_SOUND_CHECKPOINT_PATH", default="", cast_type=str)
#######################################################################
# UTILS
#######################################################################
# COLORS (ANSI escape sequences, for printing)
RED = '\033[1;31m'
GREEN = '\033[1;32m'
YELLOW = '\033[1;93m'
NC_BOLD = '\033[1m'  # Bold, No Color
NC = '\033[0m'  # No Color