Upload 137 files
Browse files
hugging/td_fuse/techniques.py
CHANGED
|
@@ -359,16 +359,26 @@ def compute_transferability_masks(
|
|
| 359 |
# For 2D weights: importance determines which rows/columns to protect
|
| 360 |
if param.dim() == 2:
|
| 361 |
rows, cols = param.shape
|
| 362 |
-
|
| 363 |
-
imp = importance[:rows] if importance.shape[0] >= rows else importance
|
| 364 |
|
| 365 |
# Compute threshold: entries at or above the (1 - threshold) quantile (the top `threshold` fraction) are task-specific; the rest are transferable
|
| 366 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
q = torch.quantile(imp.float(), 1.0 - threshold)
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
masks[param_name] = row_mask.unsqueeze(1).expand_as(param)
|
| 371 |
else:
|
|
|
|
| 372 |
masks[param_name] = torch.ones(param.shape, dtype=torch.bool)
|
| 373 |
else:
|
| 374 |
# 1D params (biases, norms): default to transferable
|
|
|
|
| 359 |
# For 2D weights: importance determines which rows/columns to protect
|
| 360 |
if param.dim() == 2:
|
| 361 |
rows, cols = param.shape
|
| 362 |
+
imp_size = importance.shape[0]
|
|
|
|
| 363 |
|
| 364 |
# Compute threshold: entries at or above the (1 - threshold) quantile (the top `threshold` fraction) are task-specific; the rest are transferable
|
| 365 |
+
if importance.numel() == 0:
|
| 366 |
+
masks[param_name] = torch.ones(param.shape, dtype=torch.bool)
|
| 367 |
+
elif imp_size >= rows:
|
| 368 |
+
# Importance covers the row dimension (e.g., 4096 importance, 4096×4096 weight)
|
| 369 |
+
imp = importance[:rows]
|
| 370 |
+
q = torch.quantile(imp.float(), 1.0 - threshold)
|
| 371 |
+
row_mask = imp < q # [rows]
|
| 372 |
+
masks[param_name] = row_mask.unsqueeze(1).expand(rows, cols)
|
| 373 |
+
elif imp_size >= cols:
|
| 374 |
+
# Importance covers the column dimension (e.g., 4096 importance, 12288×4096 weight)
|
| 375 |
+
# This happens for gate_proj, up_proj where rows=3×hidden_dim
|
| 376 |
+
imp = importance[:cols]
|
| 377 |
q = torch.quantile(imp.float(), 1.0 - threshold)
|
| 378 |
+
col_mask = imp < q # [cols]
|
| 379 |
+
masks[param_name] = col_mask.unsqueeze(0).expand(rows, cols)
|
|
|
|
| 380 |
else:
|
| 381 |
+
# Importance is shorter than both rows and cols, so it cannot cover either dimension — default to transferable
|
| 382 |
masks[param_name] = torch.ones(param.shape, dtype=torch.bool)
|
| 383 |
else:
|
| 384 |
# 1D params (biases, norms): default to transferable
|