FOXES / forecasting /inference /ablation_inference_config.yaml
griffingoodwin04's picture
slight bug fixes!
f3d1109
# =============================================================================
# FOXES Ablation Inference Configuration
# =============================================================================
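# Used by ablation_inference.py to run a channel-ablation study: selected
# wavelength channels are corrupted with Gaussian noise and the effect on the
# model's predictions is recorded.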
#
# Usage:
# python ablation_inference.py -config ablation_inference_config.yaml
#
# Variables
# ---------
# Define top-level string keys and reference them anywhere with ${key}.
# Root directory holding the processed FOXES data; referenced below as ${base_dir}.
base_dir: "/Volumes/T9/FOXES_Data"
# Path to the trained model checkpoint (.ckpt); referenced below as ${checkpoint}.
checkpoint: ""
model: "ViTLocal"
# AIA wavelength channels fed to the model.
wavelengths:
  - 94
  - 131
  - 171
  - 193
  - 211
  - 304
  - 335
# NOTE(review): this is the quoted string "false", not a YAML boolean (a
# non-empty string is truthy in Python) — confirm ablation_inference.py
# compares it as text before changing it to a bare `false`.
prediction_only: "false"
# Input locations. ${...} placeholders are replaced by the top-level variables.
data:
  aia_dir: "${base_dir}/AIA_processed/test"
  sxr_dir: "${base_dir}/SXR_processed/test"
  sxr_norm_path: "${base_dir}/SXR_processed/normalized_sxr.npy"
  checkpoint_path: "${checkpoint}"
# Results directory: each ablation condition writes its own <label>.csv here.
output_dir: "${base_dir}/inference/ablation"
# Model geometry and inference runtime settings.
# NOTE(review): confirm that ablation_inference.py reads the runtime flags
# (no_weights .. multi_gpu) from model_params rather than from the top level.
model_params:
  input_size: 512
  patch_size: 8
  batch_size: 10
  # Skip attention saving for speed.
  no_weights: true
  no_flux: true
  # Set to 0 to disable multiprocessing (useful for debugging stalls).
  num_workers: 8
  pin_memory: true
  # Set true to enable FP16 autocast (faster on A100, may OOM on V100).
  use_amp: false
  # Set true to use all available GPUs via DataParallel.
  multi_gpu: false
# -----------------------------------------------------------------------------
# Ablation conditions
# -----------------------------------------------------------------------------
# noise_std: noise level, expressed as a multiple of each channel's per-image
# standard deviation. It may also be a per-wavelength mapping,
# e.g. {171: 0.5, 193: 2.0}.
# Each condition names the wavelength subset to corrupt; an empty list is the
# clean baseline.
noise_std: 1.0
conditions:
  # Clean reference run: no channel is corrupted.
  - label: baseline
    wavelengths: []
  # Single-channel ablations: one wavelength corrupted at a time.
  - label: ablate_94
    wavelengths: [94]
  - label: ablate_131
    wavelengths: [131]
  - label: ablate_171
    wavelengths: [171]
  - label: ablate_193
    wavelengths: [193]
  - label: ablate_211
    wavelengths: [211]
  - label: ablate_304
    wavelengths: [304]
  - label: ablate_335
    wavelengths: [335]
  # Every channel corrupted at once.
  - label: ablate_all
    wavelengths: [94, 131, 171, 193, 211, 304, 335]