| # ============================================================================= | |
| # FOXES Ablation Inference Configuration | |
| # ============================================================================= | |
| # Used by ablation_inference.py to run a Gaussian-noise channel-masking ablation. | |
| # | |
| # Usage: | |
| # python ablation_inference.py -config ablation_inference_config.yaml | |
| # | |
| # Variables | |
| # --------- | |
| # Define top-level string keys below and reference them anywhere in this file as ${key} (e.g. ${base_dir}). | |
| base_dir: "/Volumes/T9/FOXES_Data" | |
| checkpoint: "" # Path to your model checkpoint (.ckpt); replace the empty default before running | |
| model: "ViTLocal" | |
| wavelengths: [94, 131, 171, 193, 211, 304, 335] | |
| prediction_only: "false" | |
| data: | |
| aia_dir: "${base_dir}/AIA_processed/test" | |
| sxr_dir: "${base_dir}/SXR_processed/test" | |
| sxr_norm_path: "${base_dir}/SXR_processed/normalized_sxr.npy" | |
| checkpoint_path: "${checkpoint}" | |
| # Output directory — each condition saves <label>.csv here | |
| output_dir: "${base_dir}/inference/ablation" | |
| model_params: | |
| input_size: 512 | |
| patch_size: 8 | |
| batch_size: 10 | |
| no_weights: true # skip attention saving for speed | |
| no_flux: true | |
| num_workers: 8 # set to 0 to disable multiprocessing (useful for debugging stalls) | |
| pin_memory: true | |
| use_amp: false # set true to enable FP16 autocast (faster on A100, may OOM on V100) | |
| multi_gpu: false # set true to use all available GPUs via DataParallel | |
| # ----------------------------------------------------------------------------- | |
| # Ablation conditions | |
| # ----------------------------------------------------------------------------- | |
| # noise_std: scale factor applied to each channel's per-image std. | |
| # Each condition names a wavelength subset to corrupt; empty list = clean baseline. | |
| # You can also pass a per-wavelength dict, e.g. {171: 0.5, 193: 2.0} | |
| noise_std: 1.0 | |
| conditions: | |
| - label: baseline | |
| wavelengths: [] | |
| - label: ablate_94 | |
| wavelengths: [94] | |
| - label: ablate_131 | |
| wavelengths: [131] | |
| - label: ablate_171 | |
| wavelengths: [171] | |
| - label: ablate_193 | |
| wavelengths: [193] | |
| - label: ablate_211 | |
| wavelengths: [211] | |
| - label: ablate_304 | |
| wavelengths: [304] | |
| - label: ablate_335 | |
| wavelengths: [335] | |
| - label: ablate_all | |
| wavelengths: [94, 131, 171, 193, 211, 304, 335] | |