Leacb4 committed on
Commit 11dbd66 (verified)
1 Parent(s): 65073e2

Upload evaluation/basic_test_generalized.py with huggingface_hub

Files changed (1)
  1. evaluation/basic_test_generalized.py +425 -0
evaluation/basic_test_generalized.py ADDED
@@ -0,0 +1,425 @@
+ """
+ Generalized evaluation of the main model with sub-module comparison.
+ This file evaluates the main model's performance by comparing specialized parts
+ (color and hierarchy) with corresponding specialized models. It calculates similarity
+ matrices, linear projections between embedding spaces, and generates detailed statistics
+ on alignment between different representations.
+ """
+
+ import os
+ import json
+ import argparse
+ import config
+ import torch
+ import torch.nn.functional as F
+ import pandas as pd
+ from PIL import Image
+ from torchvision import transforms
+ from transformers import CLIPProcessor, CLIPModel as CLIPModelTransformers
+ from tqdm.auto import tqdm
+
+ # Local imports
+ from color_model import ColorCLIP as ColorModel, ColorDataset, Tokenizer
+ from config import color_model_path, color_emb_dim, device, hierarchy_model_path, hierarchy_emb_dim
+ from hierarchy_model import Model as HierarchyModel, HierarchyExtractor
+
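+ # Assumed layout of the main model's joint embedding, per the slicing in the
+ # evaluation loop below: dims [0, color_emb_dim) hold the color sub-space and
+ # dims [color_emb_dim, color_emb_dim + hierarchy_emb_dim) hold the hierarchy sub-space.
+
+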
+ def load_color_model(color_model_path, color_emb_dim, device):
+     # Load color model weights
+     color_checkpoint = torch.load(color_model_path, map_location=device, weights_only=True)
+     color_model = ColorModel(vocab_size=39, embedding_dim=color_emb_dim).to(device)
+     color_model.load_state_dict(color_checkpoint)
+
+     # Load and set the tokenizer
+     tokenizer = Tokenizer()
+     with open(config.tokeniser_path, 'r') as f:
+         vocab_dict = json.load(f)
+     # NOTE: vocab_dict is read but never wired into the tokenizer; this assumes the
+     # default Tokenizer() already carries the same 39-token vocabulary.
+     _ = vocab_dict
+     color_model.tokenizer = tokenizer
+
+     color_model.eval()
+     return color_model
+
+
+ def get_emb_color_model(color_model, image_path_to_encode, text_to_encode):
+     # Load and preprocess image
+     image = Image.open(image_path_to_encode).convert('RGB')
+
+     transform = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+     ])
+
+     processed_image = transform(image)
+
+     # Image embedding
+     processed_image_batch = processed_image.unsqueeze(0).to(device)  # Shape: [1, 3, 224, 224]
+     with torch.no_grad():
+         image_emb = color_model.image_encoder(processed_image_batch)
+
+     # Text embedding via tokenizer + text_encoder; token_ids is always 2-D here,
+     # so the sequence length is simply its second dimension
+     token_ids = torch.tensor([color_model.tokenizer(text_to_encode)], dtype=torch.long, device=device)
+     lengths = torch.tensor([token_ids.size(1)], dtype=torch.long, device=device)
+     with torch.no_grad():
+         txt_emb = color_model.text_encoder(token_ids, lengths)
+
+     return image_emb, txt_emb
+
+ def load_main_model(main_model_path, device):
+     checkpoint = torch.load(main_model_path, map_location=device)
+     main_model = CLIPModelTransformers.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
+     state = checkpoint['model_state_dict'] if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint else checkpoint
+     try:
+         main_model.load_state_dict(state, strict=False)
+     except Exception:
+         # strict=False tolerates missing/unexpected keys but still raises on shape
+         # mismatches; fall back to loading only the keys whose shapes match
+         model_state = main_model.state_dict()
+         filtered = {k: v for k, v in state.items() if k in model_state and model_state[k].shape == v.shape}
+         main_model.load_state_dict(filtered, strict=False)
+     main_model.to(device)
+     main_model.eval()
+     processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
+     return main_model, processor
+
+
+ def load_hierarchy_model(hierarchy_model_path, device):
+     checkpoint = torch.load(hierarchy_model_path, map_location=device)
+     hierarchy_classes = checkpoint.get('hierarchy_classes', [])
+     model = HierarchyModel(num_hierarchy_classes=len(hierarchy_classes), embed_dim=config.hierarchy_emb_dim).to(device)
+     model.load_state_dict(checkpoint['model_state'])
+     extractor = HierarchyExtractor(hierarchy_classes, verbose=False)
+     model.set_hierarchy_extractor(extractor)
+     model.eval()
+     return model
+
95
+
96
+ def get_emb_hierarchy_model(hierarchy_model, image_path_to_encode, text_to_encode):
97
+ image = Image.open(image_path_to_encode).convert('RGB')
98
+ transform = transforms.Compose([
99
+ transforms.Resize((224, 224)),
100
+ transforms.ToTensor(),
101
+ ])
102
+ image_tensor = transform(image).unsqueeze(0).to(device)
103
+
104
+ with torch.no_grad():
105
+ img_emb = hierarchy_model.get_image_embeddings(image_tensor)
106
+ txt_emb = hierarchy_model.get_text_embeddings(text_to_encode)
107
+
108
+ return img_emb, txt_emb
109
+
+ def get_emb_main_model(main_model, processor, image_path_to_encode, text_to_encode):
+     image = Image.open(image_path_to_encode).convert('RGB')
+     transform = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+     ])
+     image = transform(image)
+     image = image.unsqueeze(0).to(device)
+     # Prepare text inputs via the processor
+     text_inputs = processor(text=[text_to_encode], return_tensors="pt", padding=True)
+     text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
+     outputs = main_model(**text_inputs, pixel_values=image)
+     text_emb = outputs.text_embeds
+     image_emb = outputs.image_embeds
+
+     return text_emb, image_emb
+
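+ # Note: get_emb_main_model returns (text_emb, image_emb), the reverse of the
+ # (image, text) order used by the two helpers above; callers unpack accordingly.
+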
129
+ if __name__ == '__main__':
130
+ parser = argparse.ArgumentParser(description='Evaluate main model parts vs small models and build similarity matrices')
131
+ parser.add_argument('--main-checkpoint', type=str, default='models/laion_explicable_model.pth')
132
+ parser.add_argument('--color-checkpoint', type=str, default='models/color_model.pt')
133
+ parser.add_argument('--csv', type=str, default='data/data_with_local_paths.csv')
134
+ parser.add_argument('--color-emb-dim', type=int, default=16)
135
+ parser.add_argument('--num-samples', type=int, default=200)
136
+ parser.add_argument('--seed', type=int, default=42)
137
+ parser.add_argument('--primary-metric', type=str, default='sim_color_txt_img',
138
+ choices=['sim_txt_color_part', 'sim_img_color_part', 'sim_color_txt_img', 'sim_small_txt_img',
139
+ 'sim_txt_hierarchy_part', 'sim_img_hierarchy_part'])
140
+ parser.add_argument('--top-k', type=int, default=30)
141
+ parser.add_argument('--heatmap', action='store_true')
142
+ parser.add_argument('--l2-grid', type=str, default='1e-5,1e-4,1e-3,1e-2,1e-1')
143
+ args = parser.parse_args()
144
+
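+     # Example invocation (values are illustrative):
+     #   python evaluation/basic_test_generalized.py --num-samples 50 \
+     #       --primary-metric sim_color_txt_img --heatmap
+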
+     main_checkpoint = args.main_checkpoint
+     color_checkpoint = args.color_checkpoint
+     csv = args.csv
+     color_emb_dim = args.color_emb_dim
+     num_samples = args.num_samples
+     seed = args.seed
+     primary_metric = args.primary_metric
+     top_k = args.top_k
+     l2_grid = [float(x) for x in args.l2_grid.split(',') if x]
+     # Prefer Apple's MPS backend when available; fall back to CPU otherwise
+     device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
+
+     df = pd.read_csv(csv)
+
+     # Normalize colors (reduce aliasing and sparsity)
+     def normalize_color(c):
+         if pd.isna(c):
+             return c
+         s = str(c).strip().lower()
+         aliases = {
+             'grey': 'gray',
+             'navy blue': 'navy',
+             'light blue': 'blue',
+             'dark blue': 'blue',
+             'light grey': 'gray',
+             'dark grey': 'gray',
+             'light gray': 'gray',
+             'dark gray': 'gray',
+         }
+         return aliases.get(s, s)
+
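+     # e.g. normalize_color('Navy Blue') -> 'navy', normalize_color('Grey') -> 'gray';
+     # unknown colors pass through unchanged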
+     if config.color_column in df.columns:
+         df[config.color_column] = df[config.color_column].apply(normalize_color)
+
+     color_model = load_color_model(color_checkpoint, color_emb_dim, device)
+     main_model, processor = load_main_model(main_checkpoint, device)
+     hierarchy_model = load_hierarchy_model(hierarchy_model_path, device)
+
+     # Results container
+     results = []
+
+     # Accumulators for projection (A: main part, B: small model)
+     color_txt_As, color_txt_Bs = [], []
+     color_img_As, color_img_Bs = [], []
+     hier_txt_As, hier_txt_Bs = [], []
+     hier_img_As, hier_img_Bs = [], []
+
+     # Ensure determinism for sampling
+     pd.options.mode.copy_on_write = True
+     torch.manual_seed(seed)
+
+     unique_hiers = sorted(df[config.hierarchy_column].dropna().unique())
+     unique_colors = sorted(df[config.color_column].dropna().unique())
+
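+     # The double loop visits every (hierarchy, color) combination, including pairs
+     # with no matching rows; empty groups simply contribute zero samples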
+     # Progress bar across all (hierarchy, color) pairs
+     total_pairs = len(unique_hiers) * len(unique_colors)
+     pair_pbar = tqdm(total=total_pairs, desc="Evaluating pairs", leave=False)
+     for hierarchy in unique_hiers:
+         for color in unique_colors:
+             group = df[(df[config.hierarchy_column] == hierarchy) & (df[config.color_column] == color)]
+
+             # Sample up to num_samples per (hierarchy, color)
+             k = min(num_samples, len(group))
+             group_iter = group.sample(n=k, random_state=seed) if len(group) > k else group.iloc[:k]
+
+             # Progress bar for samples within the pair
+             inner_pbar = tqdm(total=len(group_iter), desc=f"{hierarchy}/{color}", leave=False)
+             for row_idx, (_, example) in enumerate(group_iter.iterrows()):
+                 try:
+                     image_emb, txt_emb = get_emb_color_model(color_model, example['local_image_path'], example['text'])
+                     image_emb_hier, txt_emb_hier = get_emb_hierarchy_model(hierarchy_model, example['local_image_path'], example['text'])
+                     text_emb_main_model, image_emb_main_model = get_emb_main_model(
+                         main_model, processor, example['local_image_path'], example['text']
+                     )
+
+                     # Slice the main embedding into its sub-spaces: color first,
+                     # hierarchy immediately after
+                     color_part_txt = text_emb_main_model[:, :color_emb_dim]
+                     color_part_img = image_emb_main_model[:, :color_emb_dim]
+                     hier_part_txt = text_emb_main_model[:, color_emb_dim:color_emb_dim + hierarchy_emb_dim]
+                     hier_part_img = image_emb_main_model[:, color_emb_dim:color_emb_dim + hierarchy_emb_dim]
+
+                     # L2-normalize parts and small-model embeddings for stable cosine
+                     color_part_txt = F.normalize(color_part_txt, dim=1)
+                     color_part_img = F.normalize(color_part_img, dim=1)
+                     hier_part_txt = F.normalize(hier_part_txt, dim=1)
+                     hier_part_img = F.normalize(hier_part_img, dim=1)
+                     txt_emb = F.normalize(txt_emb, dim=1)
+                     image_emb = F.normalize(image_emb, dim=1)
+                     txt_emb_hier = F.normalize(txt_emb_hier, dim=1)
+                     image_emb_hier = F.normalize(image_emb_hier, dim=1)
+
+                     sim_txt_color_part = F.cosine_similarity(txt_emb, color_part_txt).item()
+                     sim_img_color_part = F.cosine_similarity(image_emb, color_part_img).item()
+                     sim_color_txt_img = F.cosine_similarity(color_part_txt, color_part_img).item()
+                     sim_small_txt_img = F.cosine_similarity(txt_emb, image_emb).item()
+
+                     sim_txt_hierarchy_part = F.cosine_similarity(txt_emb_hier, hier_part_txt).item()
+                     sim_img_hierarchy_part = F.cosine_similarity(image_emb_hier, hier_part_img).item()
+
+                     # Accumulate for projection fitting later
+                     color_txt_As.append(color_part_txt.squeeze(0).detach().cpu())
+                     color_txt_Bs.append(txt_emb.squeeze(0).detach().cpu())
+                     color_img_As.append(color_part_img.squeeze(0).detach().cpu())
+                     color_img_Bs.append(image_emb.squeeze(0).detach().cpu())
+
+                     hier_txt_As.append(hier_part_txt.squeeze(0).detach().cpu())
+                     hier_txt_Bs.append(txt_emb_hier.squeeze(0).detach().cpu())
+                     hier_img_As.append(hier_part_img.squeeze(0).detach().cpu())
+                     hier_img_Bs.append(image_emb_hier.squeeze(0).detach().cpu())
+
+                     results.append({
+                         'hierarchy': hierarchy,
+                         'color': color,
+                         'row_index': int(row_idx),
+                         'sim_txt_color_part': float(sim_txt_color_part),
+                         'sim_img_color_part': float(sim_img_color_part),
+                         'sim_color_txt_img': float(sim_color_txt_img),
+                         'sim_small_txt_img': float(sim_small_txt_img),
+                         'sim_txt_hierarchy_part': float(sim_txt_hierarchy_part),
+                         'sim_img_hierarchy_part': float(sim_img_hierarchy_part),
+                     })
+                 except Exception as e:
+                     print(f"Skipping example due to error: {e}")
+                 finally:
+                     inner_pbar.update(1)
+             inner_pbar.close()
+             pair_pbar.update(1)
+     pair_pbar.close()
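+
+     # All six metrics are cosines in [-1, 1]; since every vector was L2-normalized
+     # above, they coincide with plain dot products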
273
+
274
+ results_df = pd.DataFrame(results)
275
+
276
+ # Save raw results
277
+ os.makedirs('evaluation_outputs', exist_ok=True)
278
+ raw_path = os.path.join('evaluation_outputs', 'similarities_raw.csv')
279
+ results_df.to_csv(raw_path, index=False)
280
+ print(f"Saved raw similarities to {raw_path}")
281
+
282
+ # Intelligent averages
283
+ metrics = ['sim_txt_color_part', 'sim_img_color_part', 'sim_color_txt_img', 'sim_small_txt_img',
284
+ 'sim_txt_hierarchy_part', 'sim_img_hierarchy_part']
285
+
286
+ # Overall means
287
+ overall_means = results_df[metrics].mean().to_frame(name='mean').T
288
+ overall_means.insert(0, 'level', 'overall')
289
+
290
+ # By hierarchy
291
+ by_hierarchy = results_df.groupby(config.hierarchy_column)[metrics].mean().reset_index()
292
+ by_hierarchy.insert(0, 'level', config.hierarchy_column)
293
+
294
+ # By color
295
+ by_color = results_df.groupby(config.color_column)[metrics].mean().reset_index()
296
+ by_color.insert(0, 'level', config.color_column)
297
+
298
+ # By hierarchy+color
299
+ by_pair = results_df.groupby([config.hierarchy_column, config.color_column])[metrics].mean().reset_index()
300
+ by_pair.insert(0, 'level', 'hierarchy_color')
301
+
302
+ summary_df = pd.concat([overall_means, by_hierarchy, by_color, by_pair], ignore_index=True)
303
+ summary_path = os.path.join('evaluation_outputs', 'similarities_summary.csv')
304
+ summary_df.to_csv(summary_path, index=False)
305
+ print(f"Saved summary statistics to {summary_path}")
306
+
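+     # The summary CSV stacks four aggregation levels ('overall', 'hierarchy',
+     # 'color', 'hierarchy_color'); filter on the 'level' column to read one at a time
+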
+     # =====================
+     # Similarity matrices for the best hierarchy-color combinations
+     # =====================
+     try:
+         by_pair_core = results_df.groupby(['hierarchy', 'color'])[metrics].mean().reset_index()
+         top_pairs = by_pair_core.nlargest(top_k, primary_metric)
+         # Pivoting only the top-k pairs leaves NaN for combinations outside the top-k
+         matrix = top_pairs.pivot(index='hierarchy', columns='color', values=primary_metric)
+         os.makedirs('evaluation_outputs', exist_ok=True)
+         matrix_csv_path = os.path.join('evaluation_outputs', f'similarity_matrix_{primary_metric}_top{top_k}.csv')
+         matrix.to_csv(matrix_csv_path)
+         print(f"Saved similarity matrix to {matrix_csv_path}")
+
+         if args.heatmap:
+             try:
+                 import seaborn as sns
+                 import matplotlib.pyplot as plt
+                 plt.figure(figsize=(max(6, 0.5 * len(matrix.columns)), max(4, 0.5 * len(matrix.index))))
+                 sns.heatmap(matrix, annot=False, cmap='viridis')
+                 plt.title(f'Similarity matrix (top {top_k}) - {primary_metric}')
+                 heatmap_path = os.path.join('evaluation_outputs', f'similarity_matrix_{primary_metric}_top{top_k}.png')
+                 plt.tight_layout()
+                 plt.savefig(heatmap_path, dpi=200)
+                 plt.close()
+                 print(f"Saved similarity heatmap to {heatmap_path}")
+             except Exception as e:
+                 print(f"Skipping heatmap generation: {e}")
+     except Exception as e:
+         print(f"Skipping matrix generation: {e}")
+
+     # =====================
+     # Learn projections A->B and report projected cosine means
+     # =====================
+     def fit_ridge_projection(A, B, l2_reg=1e-3):
+         # A: list of [D_in] tensors, B: list of [D_out] tensors
+         A = torch.stack(A)  # [N, D_in]
+         B = torch.stack(B)  # [N, D_out]
+         # Closed-form ridge: W = (A^T A + lambda * I)^-1 A^T B
+         AtA = A.T @ A
+         D_in = AtA.shape[0]
+         AtA_reg = AtA + l2_reg * torch.eye(D_in)
+         W = torch.linalg.solve(AtA_reg, A.T @ B)
+         return W  # [D_in, D_out]
+
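+     # Sanity check of the closed form (illustrative): with A equal to the identity
+     # matrix and l2_reg = 0, AtA_reg = I and W = solve(I, B) = B, as expected
+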
+     def fit_ridge_with_cv(A, B, l2_values):
+         # Simple holdout CV: 80/20 split
+         if len(A) < 10:
+             # Not enough data for a split; fall back to the middle lambda
+             best_l2 = l2_values[min(len(l2_values) // 2, len(l2_values) - 1)]
+             W = fit_ridge_projection(A, B, best_l2)
+             return W, best_l2, None
+
+         N = len(A)
+         idx = torch.randperm(N)
+         split = int(0.8 * N)
+         train_idx = idx[:split]
+         val_idx = idx[split:]
+
+         A_tensor = torch.stack(A)
+         B_tensor = torch.stack(B)
+
+         A_train, B_train = A_tensor[train_idx], B_tensor[train_idx]
+         A_val, B_val = A_tensor[val_idx], B_tensor[val_idx]
+
+         def to_list(t):
+             return list(t)
+
+         best_l2 = None
+         best_score = -1.0
+         for l2 in l2_values:
+             W = fit_ridge_projection(to_list(A_train), to_list(B_train), l2)
+             score = mean_projected_cosine(to_list(A_val), to_list(B_val), W)
+             if score > best_score:
+                 best_score = score
+                 best_l2 = l2
+
+         # Refit on all data with best_l2
+         W_best = fit_ridge_projection(A, B, best_l2)
+         return W_best, best_l2, best_score
+
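+     # mean_projected_cosine (defined just below) is resolved at call time, so
+     # defining it after fit_ridge_with_cv is safe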
+     def mean_projected_cosine(A, B, W):
+         A = torch.stack(A)
+         B = torch.stack(B)
+         A_proj = A @ W
+         A_proj = F.normalize(A_proj, dim=1)
+         B = F.normalize(B, dim=1)
+         return torch.mean(torch.sum(A_proj * B, dim=1)).item()
+
+     projection_report = {}
+
+     if len(color_txt_As) >= 8:
+         W_ct, best_l2_ct, cv_ct = fit_ridge_with_cv(color_txt_As, color_txt_Bs, l2_grid)
+         projection_report['proj_sim_txt_color_part_mean'] = mean_projected_cosine(color_txt_As, color_txt_Bs, W_ct)
+         projection_report['proj_txt_color_part_best_l2'] = best_l2_ct
+         if cv_ct is not None:
+             projection_report['proj_txt_color_part_cv_val'] = cv_ct
+     if len(color_img_As) >= 8:
+         W_ci, best_l2_ci, cv_ci = fit_ridge_with_cv(color_img_As, color_img_Bs, l2_grid)
+         projection_report['proj_sim_img_color_part_mean'] = mean_projected_cosine(color_img_As, color_img_Bs, W_ci)
+         projection_report['proj_img_color_part_best_l2'] = best_l2_ci
+         if cv_ci is not None:
+             projection_report['proj_img_color_part_cv_val'] = cv_ci
+     if len(hier_txt_As) >= 8:
+         W_ht, best_l2_ht, cv_ht = fit_ridge_with_cv(hier_txt_As, hier_txt_Bs, l2_grid)
+         projection_report['proj_sim_txt_hierarchy_part_mean'] = mean_projected_cosine(hier_txt_As, hier_txt_Bs, W_ht)
+         projection_report['proj_txt_hierarchy_part_best_l2'] = best_l2_ht
+         if cv_ht is not None:
+             projection_report['proj_txt_hierarchy_part_cv_val'] = cv_ht
+     if len(hier_img_As) >= 8:
+         W_hi, best_l2_hi, cv_hi = fit_ridge_with_cv(hier_img_As, hier_img_Bs, l2_grid)
+         projection_report['proj_sim_img_hierarchy_part_mean'] = mean_projected_cosine(hier_img_As, hier_img_Bs, W_hi)
+         projection_report['proj_img_hierarchy_part_best_l2'] = best_l2_hi
+         if cv_hi is not None:
+             projection_report['proj_img_hierarchy_part_cv_val'] = cv_hi
+
+     proj_summary_path = os.path.join('evaluation_outputs', 'projection_summary.json')
+     with open(proj_summary_path, 'w') as f:
+         json.dump(projection_report, f, indent=2)
+     print(f"Saved projection summary to {proj_summary_path}")
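+     # The resulting JSON maps metric names to floats, e.g. (values illustrative):
+     #   {"proj_sim_txt_color_part_mean": 0.87, "proj_txt_color_part_best_l2": 1e-3, ...}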