File size: 5,465 Bytes
e330ebf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
############################################################################################
# Name: test_parameter.py
#
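# NOTE: Set all evaluation hyper-parameters here. Settings read via getenv()
# can also be overridden with environment variables of the same name.
# Quick how-to:
# 1. CLIP with TTA:    USE_CLIP_PREDS=True,  EXECUTE_TTA=True
# 2. CLIP without TTA: USE_CLIP_PREDS=True,  EXECUTE_TTA=False
# 3. Custom masks (e.g. LISA): USE_CLIP_PREDS=False, EXECUTE_TTA=False,
#    with the masks supplied via OVERRIDE_MASK_DIR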
############################################################################################

import os
import sys
# Publish the run mode through sys.modules: False means inference/testing, not training.
# NOTE(review): presumably other modules read this back as sys.modules['TRAINING'];
# the value is a plain bool, not a module object — confirm all readers before changing it.
sys.modules.update({'TRAINING': False})

###############################################################
# Overload Params
###############################################################

OPT_VARS = {}  # env-var name -> resolved value, filled in by getenv() as each setting is read


def getenv(var_name, default=None, cast_type=str):
    """Read an environment variable, cast it, and record the result in OPT_VARS.

    Args:
        var_name: Name of the environment variable to read.
        default: Value used when the variable is unset or cannot be cast.
        cast_type: Callable that converts the raw string (e.g. ``str``, ``int``,
            ``float``). ``bool`` is special-cased: "true", "1" and "yes"
            (case-insensitive, surrounding whitespace ignored) map to True and
            any other value maps to False.

    Returns:
        The cast value, or ``default`` when the variable is unset or malformed.
    """
    value = os.environ.get(var_name)
    if value is None:
        result = default
    elif cast_type is bool:
        # bool("false") would be True, so booleans are parsed explicitly.
        # strip() makes padded values such as "True " behave the way int()/float()
        # already treat padded numbers.
        result = value.strip().lower() in ("true", "1", "yes")
    else:
        try:
            result = cast_type(value)
        except (ValueError, TypeError):
            # Malformed value: fall back to the default instead of failing on import.
            result = default

    OPT_VARS[var_name] = result  # Log the resolved value
    return result

###############################################################
# General
###############################################################

# --- GENERAL --- #
# Settings read through getenv() can be overridden from the environment; the rest are fixed here.
USE_GPU = False
NUM_GPU = getenv("NUM_GPU", default=1, cast_type=int)                # number of GPUs
NUM_META_AGENT = getenv("NUM_META_AGENT", default=2, cast_type=int)  # number of concurrent processes
NUM_EPS_STEPS = getenv("NUM_EPS_STEPS", default=400, cast_type=int)
# If True, robots start from a fixed position (the middle index).
FIX_START_POSITION = getenv("FIX_START_POSITION", default=True, cast_type=bool)
NUM_ROBOTS = 1  # only a single robot is supported

# Node-coordinate grid and CLIP patch grid
NUM_COORDS_WIDTH = 24       # node coords across the map width
NUM_COORDS_HEIGHT = 24      # node coords across the map height
CLIP_GRIDS_DIMS = [24, 24]  # use [16, 16] with 'openai/clip-vit-large-patch14-336'

# Sensing
SENSOR_MODEL = "rectangular"  # "rectangular" or "circle" (NOTE: no collision check for rectangular)
SENSOR_RANGE = 80             # only used by the 'circle' sensor model

# Episode termination / logging once every target is found
TERMINATE_ON_TGTS_FOUND = True        # end the episode when all targets are found
FORCE_LOGGING_DONE_TGTS_FOUND = True  # force CSV logging when all targets are found


# --- Planner Params --- #
POLICY = getenv("POLICY", default="RL", cast_type=str)  # overridable via $POLICY
NUM_TEST = 800                    # overridden if LOAD_AVS_BENCH
NUM_RUN = 1
MODEL_NAME = "avs_rl_policy.pth"  # RL policy checkpoint file

# Network hyper-parameters — NOTE(review): their meaning is defined by the model code that reads them.
INPUT_DIM = 4
EMBEDDING_DIM = 128
K_SIZE = 8


# --- Folders & Visualizations --- #
GRIDMAP_SET_DIR = "maps/gpt4o/envs_val"
MASK_SET_DIR = "maps/example/masks_val"  # overridden if LOAD_AVS_BENCH
# Empty by default; an example value is "maps/example/gt_masks_val_with_tgts"
# (overridden if LOAD_AVS_BENCH).
TARGETS_SET_DIR = ""
# Optional directory whose masks replace the initial score mask from CLIP.
OVERRIDE_MASK_DIR = getenv("OVERRIDE_MASK_DIR", default="", cast_type=str)
SAVE_GIFS = getenv("SAVE_GIFS", default=True, cast_type=bool)  # whether to save GIFs
FOLDER_NAME = 'avs_search'
# Output locations. GIF and log outputs are namespaced by FOLDER_NAME; the other
# two paths are plain literals (no f-prefix needed — they have no placeholders).
MODEL_PATH = 'inference/model'
GIFS_PATH = f'inference/test_results/gifs/{FOLDER_NAME}'
LOG_PATH = f'inference/test_results/log/{FOLDER_NAME}'
LOG_TEMPLATE_XLSX = 'inference/template.xlsx'
VIZ_GRAPH_EDGES = False  # set True to draw the graph edges in visualizations
CSV_EXPT_NAME = getenv("CSV_EXPT_NAME", default="data", cast_type=str)  # CSV experiment name (overridable via $CSV_EXPT_NAME)


#######################################################################
# AVS Params
#######################################################################

# General PARAMS
# With USE_CLIP_PREDS=False, custom masks from OVERRIDE_MASK_DIR are used instead of CLIP predictions.
USE_CLIP_PREDS = getenv("USE_CLIP_PREDS", default=True, cast_type=bool)
QUERY_TAX = getenv("QUERY_TAX", default="", cast_type=str)                 # "" tests every taxonomy; substrings are accepted
EXECUTE_TTA = getenv("EXECUTE_TTA", default=True, cast_type=bool)          # run TTA mask updates
QUERY_MODALITY = getenv("QUERY_MODALITY", default="image", cast_type=str)  # one of "image", "text", "sound"

# TTA schedule
STEPS_PER_TTA = 20  # steps taken before each TTA series
NUM_TTA_STEPS = 1   # TTA steps performed within each series

# Optimisation knobs — NOTE(review): presumably consumed by the TTA updates; confirm in the code that reads them.
RESET_WEIGHTS = True
MIN_LR = 1e-6
MAX_LR = 1e-5
GAMMA_EXPONENT = 2

# Paths related to AVS (TRAIN w/ TARGETS)
LOAD_AVS_BENCH = True  # initialise the AVS datasets (NUM_TEST and MASK_SET_DIR are overridden when True)

# Dataset locations on disk.
# NOTE(review): the iNat and satellite paths point at train splits while the sound dir
# points at test — confirm this mix is intended for evaluation.
AVS_IMG_DIR = '/mnt/hdd/avs_bench_ds/inat21'
AVS_INAT_JSON_PATH = '/mnt/hdd/avs_bench_ds/inat21/train.json'
AVS_IMO_DIR = '/mnt/hdd/avs_bench_ds/sat_jpg/train_512px'
AVS_SOUND_DIR = '/mnt/hdd/avs_bench_ds/sound_mp3/test'
AVS_GAUSSIAN_BLUR_KERNEL = (5, 5)

# Environment-overridable settings
AVS_SAT_TO_IMG_IDS_PATH = getenv("AVS_SAT_TO_IMG_IDS_PATH", default="search_tri_modal|val_in_domain", cast_type=str)
# When False, weights are loaded locally from the *_CHECKPOINT_PATH settings below.
AVS_LOAD_PRETRAINED_HF_CHECKPOINT = getenv("AVS_LOAD_PRETRAINED_HF_CHECKPOINT", default=True, cast_type=bool)
AVS_SAT_CHECKPOINT_PATH = getenv("AVS_SAT_CHECKPOINT_PATH", default="", cast_type=str)
AVS_SOUND_CHECKPOINT_PATH = getenv("AVS_SOUND_CHECKPOINT_PATH", default="", cast_type=str)

#######################################################################
# UTILS
#######################################################################

# ANSI SGR escape sequences for coloured terminal output
RED = '\033[1;31m'     # bold red
GREEN = '\033[1;32m'   # bold green
YELLOW = '\033[1;93m'  # bold bright yellow
NC_BOLD = '\033[1m'    # bold, default colour
NC = '\033[0m'         # reset all attributes