Upload 137 files
Browse files
- hugging/td_fuse/merge.py +10 -2
hugging/td_fuse/merge.py
CHANGED
|
@@ -903,12 +903,20 @@ def run_single_merge(
|
|
| 903 |
# --- Step 8.5: Extract post-merge activations for ARM/OTMF ---
|
| 904 |
print(f"\n[merge] Step 8.5/10: Post-merge activations + ARM/OTMF prep..."); sys.stdout.flush()
|
| 905 |
step_t = time.time()
|
| 906 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 907 |
|
| 908 |
# Record this merge's delta + compute ARM/OTMF for next merge
|
| 909 |
protection.after_merge(
|
| 910 |
target_model, pre_merge_state,
|
| 911 |
-
pre_merge_activations=
|
| 912 |
post_merge_activations=post_merge_activations,
|
| 913 |
)
|
| 914 |
|
|
|
|
| 903 |
# --- Step 8.5: Extract post-merge activations for ARM/OTMF ---
|
| 904 |
print(f"\n[merge] Step 8.5/10: Post-merge activations + ARM/OTMF prep..."); sys.stdout.flush()
|
| 905 |
step_t = time.time()
|
| 906 |
+
arm_sample_size = 100 # Use a small subset for speed
|
| 907 |
+
post_merge_activations = extract_activations(target_model, calibration_data[:arm_sample_size])
|
| 908 |
+
|
| 909 |
+
# Slice pre_merge_target_activations to match post_merge sample count
|
| 910 |
+
# (pre_merge used all 1500 samples, post_merge uses 100 — ARM needs same shape)
|
| 911 |
+
pre_merge_activations_subset = {}
|
| 912 |
+
for key in pre_merge_target_activations:
|
| 913 |
+
act = pre_merge_target_activations[key]
|
| 914 |
+
pre_merge_activations_subset[key] = act[:arm_sample_size]
|
| 915 |
|
| 916 |
# Record this merge's delta + compute ARM/OTMF for next merge
|
| 917 |
protection.after_merge(
|
| 918 |
target_model, pre_merge_state,
|
| 919 |
+
pre_merge_activations=pre_merge_activations_subset,
|
| 920 |
post_merge_activations=post_merge_activations,
|
| 921 |
)
|
| 922 |
|