griffingoodwin04 commited on
Commit
0affdc2
·
1 Parent(s): e5d6619

refactor configuration files and update paths

Browse files
analysis/ablation_lollipop.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import pandas as pd
3
  import numpy as np
@@ -15,7 +16,7 @@ def setup_barlow_font():
15
  rcParams['font.family'] = 'Barlow'
16
  else:
17
  for path in ['/usr/share/fonts/truetype/barlow/Barlow-Regular.ttf',
18
- '/Users/griffingoodwin/Library/Fonts/Barlow-Regular.otf']:
19
  if os.path.exists(path):
20
  fm.fontManager.addfont(path)
21
  rcParams['font.family'] = 'Barlow'
@@ -27,9 +28,6 @@ def setup_barlow_font():
27
 
28
  setup_barlow_font()
29
 
30
- DATA_DIR = "/Users/griffingoodwin/Documents/gitrepos/FOXES/Untracked/data"
31
- BASELINE_CSV = "/Volumes/T9/FOXES_Misc/batch_results/vit/vit_predictions_test.csv"
32
-
33
  WAVELENGTHS = ["94", "131", "171", "193", "211", "304", "335","STEREO"]
34
  LABELS = {
35
  "94": "Ablate 94 Å",
@@ -67,96 +65,110 @@ def compute_row(label, gt, pred, is_baseline=False):
67
  row[cls] = np.mean(np.abs(np.log10(gt[m]) - np.log10(pred[m]))) if m.sum() > 5 else np.nan
68
  return row
69
 
70
- records = []
71
-
72
- # Baseline
73
- bl = pd.read_csv(BASELINE_CSV)
74
- records.append(compute_row("FOXES (no ablation)",
75
- bl["groundtruth"].values, bl["predictions"].values,
76
- is_baseline=True))
77
-
78
- for wav in WAVELENGTHS:
79
- ab = pd.read_csv(f"{DATA_DIR}/ablate_{wav}_global_1.csv")
80
- records.append(compute_row(LABELS[wav], ab["groundtruth"].values, ab["predictions"].values))
81
-
82
- # Sort ablation rows by overall MAE (worst first), keep baseline pinned at bottom
83
- ablation_df = pd.DataFrame([r for r in records if not r["is_baseline"]])
84
- ablation_df = ablation_df.sort_values("overall", ascending=False).reset_index(drop=True)
85
- baseline_df = pd.DataFrame([r for r in records if r["is_baseline"]])
86
- df = pd.concat([ablation_df, baseline_df], ignore_index=True)
87
-
88
- # ── Plot ───────────────────────────────────────────────────────────────────────
89
- n_rows = len(df)
90
- fig, ax = plt.subplots(figsize=(11, 0.6 * n_rows + 1.5))
91
- #ax.set_facecolor("#FAFAFA")
92
- fig.patch.set_facecolor("#FFFFFF")
93
-
94
- y_positions = np.arange(n_rows)
95
-
96
- # Separator line between ablations and baseline
97
- ax.axhline(y=n_rows - 1.5, color="#BBBBBB", linewidth=1, linestyle=":", zorder=1)
98
-
99
- for i, row in df.iterrows():
100
- y = y_positions[i]
101
- is_bl = row["is_baseline"]
102
-
103
- # Highlight baseline row
104
- if is_bl:
105
- ax.axhspan(y - 0.45, y + 0.45, color="#EEF6FF", zorder=0)
106
-
107
- # Span line across per-class range
108
- class_vals = [row[c] for c in FLARE_CLASSES if not np.isnan(row[c])]
109
- if class_vals:
110
- ax.hlines(y, min(class_vals), max(class_vals),
111
- color="#CCCCCC", linewidth=2, zorder=1)
112
-
113
- # Stem from 0 to overall
114
- ax.hlines(y, 0, row["overall"],
115
- color="#AAAAAA", linewidth=1.2, linestyle="--", zorder=0, alpha=0.6)
116
-
117
- # Per-class dots
118
- for cls in FLARE_CLASSES:
119
- val = row[cls]
120
- if not np.isnan(val):
121
- ax.scatter(val, y, color=CLASS_COLORS[cls], s=80, zorder=4,
122
- edgecolors="white", linewidths=0.6, alpha=0.75)
123
-
124
- # Overall dot
125
- outline_color = "#1A6BBF" if is_bl else "black"
126
- ax.scatter(row["overall"], y, color="white", s=190, zorder=3,
127
- edgecolors=outline_color, linewidths=2.0 if is_bl else 1.5, alpha=0.75)
128
- ax.scatter(row["overall"], y, color=outline_color, s=75, zorder=3,
129
- marker="|", linewidths=1.5, alpha=0.75)
130
-
131
- tick_colors = ["black"] * n_rows
132
- tick_colors[-1] = "#1A6BBF" # baseline label in blue
133
- ax.set_yticks(y_positions)
134
- ax.set_yticklabels(df["label"], fontsize=12)
135
- for ticklabel, color in zip(ax.get_yticklabels(), tick_colors):
136
- ticklabel.set_color(color)
137
- if color != "black":
138
- ticklabel.set_fontweight("bold")
139
- ax.set_xlabel("MAE (log$_{10}$ scale)", fontsize=12)
140
- ax.grid(True, axis="x", alpha=0.4, color="#CCCCCC", linewidth=0.6)
141
- ax.set_axisbelow(True)
142
- ax.spines[["top", "right"]].set_visible(False)
143
- ax.tick_params(axis="y", length=0, labelsize=11)
144
- ax.tick_params(axis="x", labelsize=10)
145
-
146
- # Legend
147
- class_patches = [
148
- mpatches.Patch(color=CLASS_COLORS[c], label=f"{c}-class") for c in FLARE_CLASSES
149
- ]
150
- overall_patch = mpatches.Patch(facecolor="white", edgecolor="black", label="Overall")
151
- #baseline_patch = mpatches.Patch(facecolor="white", edgecolor="#1A6BBF", label="Baseline (overall)")
152
- ax.legend(handles=class_patches + [overall_patch],
153
- loc="upper right", fontsize=10, framealpha=0.9,
154
- edgecolor="#CCCCCC")
155
-
156
- # ax.set_title("Ablation Study — Log MAE by Channel & Flare Class",
157
- # fontsize=14, fontweight="bold", pad=14)
158
- plt.xlim(0, .85)
159
- plt.tight_layout()
160
- plt.savefig("ablation_lollipop.png", dpi=450, bbox_inches="tight")
161
- plt.show()
162
- print("Saved: analysis/ablation_lollipop.png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
  import os
3
  import pandas as pd
4
  import numpy as np
 
16
  rcParams['font.family'] = 'Barlow'
17
  else:
18
  for path in ['/usr/share/fonts/truetype/barlow/Barlow-Regular.ttf',
19
+ os.path.expanduser('~/Library/Fonts/Barlow-Regular.otf')]:
20
  if os.path.exists(path):
21
  fm.fontManager.addfont(path)
22
  rcParams['font.family'] = 'Barlow'
 
28
 
29
  setup_barlow_font()
30
 
 
 
 
31
  WAVELENGTHS = ["94", "131", "171", "193", "211", "304", "335","STEREO"]
32
  LABELS = {
33
  "94": "Ablate 94 Å",
 
65
  row[cls] = np.mean(np.abs(np.log10(gt[m]) - np.log10(pred[m]))) if m.sum() > 5 else np.nan
66
  return row
67
 
68
+ if __name__ == "__main__":
69
+ parser = argparse.ArgumentParser(description="Ablation lollipop plot")
70
+ parser.add_argument("--data_dir", required=True,
71
+ help="Directory containing ablate_<wavelength>_global_1.csv files")
72
+ parser.add_argument("--baseline_csv", required=True,
73
+ help="Baseline predictions CSV (groundtruth + predictions columns)")
74
+ parser.add_argument("--out", default="ablation_lollipop.png",
75
+ help="Output image path (default: ablation_lollipop.png)")
76
+ args = parser.parse_args()
77
+
78
+ DATA_DIR = args.data_dir
79
+ BASELINE_CSV = args.baseline_csv
80
+ OUT_PATH = args.out
81
+
82
+ records = []
83
+
84
+ # Baseline
85
+ bl = pd.read_csv(BASELINE_CSV)
86
+ records.append(compute_row("FOXES (no ablation)",
87
+ bl["groundtruth"].values, bl["predictions"].values,
88
+ is_baseline=True))
89
+
90
+ for wav in WAVELENGTHS:
91
+ ab = pd.read_csv(f"{DATA_DIR}/ablate_{wav}_global_1.csv")
92
+ records.append(compute_row(LABELS[wav], ab["groundtruth"].values, ab["predictions"].values))
93
+
94
+ # Sort ablation rows by overall MAE (worst first), keep baseline pinned at bottom
95
+ ablation_df = pd.DataFrame([r for r in records if not r["is_baseline"]])
96
+ ablation_df = ablation_df.sort_values("overall", ascending=False).reset_index(drop=True)
97
+ baseline_df = pd.DataFrame([r for r in records if r["is_baseline"]])
98
+ df = pd.concat([ablation_df, baseline_df], ignore_index=True)
99
+
100
+ # ── Plot ───────────────────────────────────────────────────────────────────────
101
+ n_rows = len(df)
102
+ fig, ax = plt.subplots(figsize=(11, 0.6 * n_rows + 1.5))
103
+ #ax.set_facecolor("#FAFAFA")
104
+ fig.patch.set_facecolor("#FFFFFF")
105
+
106
+ y_positions = np.arange(n_rows)
107
+
108
+ # Separator line between ablations and baseline
109
+ ax.axhline(y=n_rows - 1.5, color="#BBBBBB", linewidth=1, linestyle=":", zorder=1)
110
+
111
+ for i, row in df.iterrows():
112
+ y = y_positions[i]
113
+ is_bl = row["is_baseline"]
114
+
115
+ # Highlight baseline row
116
+ if is_bl:
117
+ ax.axhspan(y - 0.45, y + 0.45, color="#EEF6FF", zorder=0)
118
+
119
+ # Span line across per-class range
120
+ class_vals = [row[c] for c in FLARE_CLASSES if not np.isnan(row[c])]
121
+ if class_vals:
122
+ ax.hlines(y, min(class_vals), max(class_vals),
123
+ color="#CCCCCC", linewidth=2, zorder=1)
124
+
125
+ # Stem from 0 to overall
126
+ ax.hlines(y, 0, row["overall"],
127
+ color="#AAAAAA", linewidth=1.2, linestyle="--", zorder=0, alpha=0.6)
128
+
129
+ # Per-class dots
130
+ for cls in FLARE_CLASSES:
131
+ val = row[cls]
132
+ if not np.isnan(val):
133
+ ax.scatter(val, y, color=CLASS_COLORS[cls], s=80, zorder=4,
134
+ edgecolors="white", linewidths=0.6, alpha=0.75)
135
+
136
+ # Overall dot
137
+ outline_color = "#1A6BBF" if is_bl else "black"
138
+ ax.scatter(row["overall"], y, color="white", s=190, zorder=3,
139
+ edgecolors=outline_color, linewidths=2.0 if is_bl else 1.5, alpha=0.75)
140
+ ax.scatter(row["overall"], y, color=outline_color, s=75, zorder=3,
141
+ marker="|", linewidths=1.5, alpha=0.75)
142
+
143
+ tick_colors = ["black"] * n_rows
144
+ tick_colors[-1] = "#1A6BBF" # baseline label in blue
145
+ ax.set_yticks(y_positions)
146
+ ax.set_yticklabels(df["label"], fontsize=12)
147
+ for ticklabel, color in zip(ax.get_yticklabels(), tick_colors):
148
+ ticklabel.set_color(color)
149
+ if color != "black":
150
+ ticklabel.set_fontweight("bold")
151
+ ax.set_xlabel("MAE (log$_{10}$ scale)", fontsize=12)
152
+ ax.grid(True, axis="x", alpha=0.4, color="#CCCCCC", linewidth=0.6)
153
+ ax.set_axisbelow(True)
154
+ ax.spines[["top", "right"]].set_visible(False)
155
+ ax.tick_params(axis="y", length=0, labelsize=11)
156
+ ax.tick_params(axis="x", labelsize=10)
157
+
158
+ # Legend
159
+ class_patches = [
160
+ mpatches.Patch(color=CLASS_COLORS[c], label=f"{c}-class") for c in FLARE_CLASSES
161
+ ]
162
+ overall_patch = mpatches.Patch(facecolor="white", edgecolor="black", label="Overall")
163
+ #baseline_patch = mpatches.Patch(facecolor="white", edgecolor="#1A6BBF", label="Baseline (overall)")
164
+ ax.legend(handles=class_patches + [overall_patch],
165
+ loc="upper right", fontsize=10, framealpha=0.9,
166
+ edgecolor="#CCCCCC")
167
+
168
+ # ax.set_title("Ablation Study — Log MAE by Channel & Flare Class",
169
+ # fontsize=14, fontweight="bold", pad=14)
170
+ plt.xlim(0, .85)
171
+ plt.tight_layout()
172
+ plt.savefig(OUT_PATH, dpi=450, bbox_inches="tight")
173
+ plt.show()
174
+ print(f"Saved: {OUT_PATH}")
analysis/spatial_performance.py CHANGED
@@ -35,10 +35,10 @@ sys.path.insert(0, str(PROJECT_ROOT))
35
  from forecasting.inference.evaluation import setup_barlow_font
36
 
37
  # ---------------------------------------------------------------------------
38
- # Paths — edit here
39
  # ---------------------------------------------------------------------------
40
- FLUX_DIR = "/Volumes/T9/FOXES_Data/flux/"
41
- PREDICTIONS_CSV = "/Volumes/T9/FOXES_Misc/batch_results/vit/vit_predictions_test.csv"
42
  OUT_DIR = Path(__file__).parent
43
  GRID_SIZE = 64 # 512px / 8px patch size
44
  BIN_SIZE = 1 # downsample factor (1 = full 64×64 resolution)
 
35
  from forecasting.inference.evaluation import setup_barlow_font
36
 
37
  # ---------------------------------------------------------------------------
38
+ # Paths — override via CLI args or environment variables
39
  # ---------------------------------------------------------------------------
40
+ FLUX_DIR = os.environ.get("FOXES_FLUX_DIR", "")
41
+ PREDICTIONS_CSV = os.environ.get("FOXES_PREDICTIONS_CSV", "")
42
  OUT_DIR = Path(__file__).parent
43
  GRID_SIZE = 64 # 512px / 8px patch size
44
  BIN_SIZE = 1 # downsample factor (1 = full 64×64 resolution)
forecasting/inference/ablation_inference_config.yaml CHANGED
@@ -11,7 +11,7 @@
11
  # Define top-level string keys and reference them anywhere with ${key}.
12
 
13
  base_dir: "/Volumes/T9/FOXES_Data"
14
- checkpoint: "/Users/griffingoodwin/Downloads/FOXES_Model_Checkpoint.ckpt"
15
 
16
  model: "ViTLocal"
17
  wavelengths: [94, 131, 171, 193, 211, 304, 335]
 
11
  # Define top-level string keys and reference them anywhere with ${key}.
12
 
13
  base_dir: "/Volumes/T9/FOXES_Data"
14
+ checkpoint: "" # Path to your model checkpoint (.ckpt)
15
 
16
  model: "ViTLocal"
17
  wavelengths: [94, 131, 171, 193, 211, 304, 335]
forecasting/inference/evaluation.py CHANGED
@@ -39,7 +39,7 @@ def setup_barlow_font():
39
  else:
40
  # Try alternative approach - directly specify font file
41
  barlow_path = '/usr/share/fonts/truetype/barlow/Barlow-Regular.ttf'
42
- barlow_path2 = '/Users/griffingoodwin/Library/Fonts/Barlow-Regular.otf'
43
  if os.path.exists(barlow_path):
44
  # Add the font file directly to matplotlib
45
  fm.fontManager.addfont(barlow_path)
 
39
  else:
40
  # Try alternative approach - directly specify font file
41
  barlow_path = '/usr/share/fonts/truetype/barlow/Barlow-Regular.ttf'
42
+ barlow_path2 = os.path.expanduser('~/Library/Fonts/Barlow-Regular.otf')
43
  if os.path.exists(barlow_path):
44
  # Add the font file directly to matplotlib
45
  fm.fontManager.addfont(barlow_path)
forecasting/training/train_config.yaml CHANGED
@@ -44,12 +44,12 @@ data:
44
  "${base_checkpoint_dir}/new-checkpoint/"
45
 
46
  wandb:
47
- entity: jayantbiradar619-university-of-arizona # Use your exact W&B username
48
- project: Paper
49
  job_type: training
50
  tags:
51
  - aia
52
  - sxr
53
  - regression
54
- run_name: paper-8-patch-4ch
55
- notes: Regression from AIA images to SXR images using ViTLocal model with 8x8 patches
 
44
  "${base_checkpoint_dir}/new-checkpoint/"
45
 
46
  wandb:
47
+ entity: "" # Set to your W&B username or team name
48
+ project: FOXES
49
  job_type: training
50
  tags:
51
  - aia
52
  - sxr
53
  - regression
54
+ run_name: run_1
55
+ notes: AIA to SXR translation
pipeline_config.yaml CHANGED
@@ -14,7 +14,7 @@
14
  # Change base_dir or checkpoint once and every path updates automatically.
15
 
16
  base_dir: "/Volumes/T9/FOXES_Data"
17
- checkpoint: "/Users/griffingoodwin/Downloads/FOXES_Model_Checkpoint.ckpt"
18
 
19
  # -----------------------------------------------------------------------------
20
  # HuggingFace download (step: hf_download)
@@ -102,7 +102,7 @@ train:
102
  batch_size: 6
103
  wandb:
104
  run_name: "pipeline-run"
105
- entity: jayantbiradar619-university-of-arizona
106
  project: Paper
107
  job_type: training
108
  tags:
 
14
  # Change base_dir or checkpoint once and every path updates automatically.
15
 
16
  base_dir: "/Volumes/T9/FOXES_Data"
17
+ checkpoint: "" # Path to your model checkpoint (.ckpt)
18
 
19
  # -----------------------------------------------------------------------------
20
  # HuggingFace download (step: hf_download)
 
102
  batch_size: 6
103
  wandb:
104
  run_name: "pipeline-run"
105
+ entity: "" # Set to your W&B username or team name
106
  project: Paper
107
  job_type: training
108
  tags: