td-builder commited on
Commit
7cf8e19
·
verified ·
1 Parent(s): 73b6fbc

Upload 137 files

Browse files
Files changed (1) hide show
  1. hugging/td_fuse/techniques.py +16 -6
hugging/td_fuse/techniques.py CHANGED
@@ -359,16 +359,26 @@ def compute_transferability_masks(
359
  # For 2D weights: importance determines which rows/columns to protect
360
  if param.dim() == 2:
361
  rows, cols = param.shape
362
- # Use importance for the output dimension
363
- imp = importance[:rows] if importance.shape[0] >= rows else importance
364
 
365
  # Compute threshold: top (1-threshold) fraction is task-specific
366
- if imp.numel() > 0:
 
 
 
 
 
 
 
 
 
 
 
367
  q = torch.quantile(imp.float(), 1.0 - threshold)
368
- # True = transferable (below threshold), False = task-specific (protect)
369
- row_mask = imp < q
370
- masks[param_name] = row_mask.unsqueeze(1).expand_as(param)
371
  else:
 
372
  masks[param_name] = torch.ones(param.shape, dtype=torch.bool)
373
  else:
374
  # 1D params (biases, norms): default to transferable
 
359
  # For 2D weights: importance determines which rows/columns to protect
360
  if param.dim() == 2:
361
  rows, cols = param.shape
362
+ imp_size = importance.shape[0]
 
363
 
364
  # Compute threshold: top (1-threshold) fraction is task-specific
365
+ if importance.numel() == 0:
366
+ masks[param_name] = torch.ones(param.shape, dtype=torch.bool)
367
+ elif imp_size >= rows:
368
+ # Importance covers the row dimension (e.g., 4096 importance, 4096×4096 weight)
369
+ imp = importance[:rows]
370
+ q = torch.quantile(imp.float(), 1.0 - threshold)
371
+ row_mask = imp < q # [rows]
372
+ masks[param_name] = row_mask.unsqueeze(1).expand(rows, cols)
373
+ elif imp_size >= cols:
374
+ # Importance covers the column dimension (e.g., 4096 importance, 12288×4096 weight)
375
+ # This happens for gate_proj, up_proj where rows=3×hidden_dim
376
+ imp = importance[:cols]
377
  q = torch.quantile(imp.float(), 1.0 - threshold)
378
+ col_mask = imp < q # [cols]
379
+ masks[param_name] = col_mask.unsqueeze(0).expand(rows, cols)
 
380
  else:
381
+ # Importance doesn't match either dimension — default to transferable
382
  masks[param_name] = torch.ones(param.shape, dtype=torch.bool)
383
  else:
384
  # 1D params (biases, norms): default to transferable