CompressedGemma commited on
Commit
c9097e7
·
verified ·
1 Parent(s): e81a80a
Files changed (1) hide show
  1. generate_imatrix.py +8 -4
generate_imatrix.py CHANGED
@@ -673,6 +673,10 @@ class TransformerRunner:
673
  for name, arr, cnt in imp_refs:
674
  self.importance[name] = (arr.astype(np.float64), cnt.value)
675
 
 
 
 
 
676
  return hidden
677
 
678
  def _hpc_rms_norm(self, x, weight, eps):
@@ -1317,7 +1321,8 @@ class TransformerRunner:
1317
  if embed_w is None:
1318
  raise RuntimeError("Missing token_embd.weight")
1319
 
1320
- hidden = embed_w[token_ids] # [seq_len, n_embd]
 
1321
 
1322
  # RoPE frequencies
1323
  cos_f, sin_f = rope_freqs(self.head_dim, seq_len, cfg['rope_base'])
@@ -1346,9 +1351,8 @@ class TransformerRunner:
1346
  if self.verbose and (layer_idx + 1) % 4 == 0:
1347
  print(f" Layer {layer_idx + 1}/{cfg['n_layers']}", end='\r')
1348
 
1349
- # Output projection
1350
- output_w = self._get_weight('output.weight')
1351
- if output_w is not None:
1352
  self._record('output.weight', hidden)
1353
 
1354
  return hidden
 
673
  for name, arr, cnt in imp_refs:
674
  self.importance[name] = (arr.astype(np.float64), cnt.value)
675
 
676
+ # Force-free per-layer weight buffers (~1.4 GB) before next layer
677
+ del refs, imp_refs
678
+ import gc; gc.collect()
679
+
680
  return hidden
681
 
682
  def _hpc_rms_norm(self, x, weight, eps):
 
1321
  if embed_w is None:
1322
  raise RuntimeError("Missing token_embd.weight")
1323
 
1324
+ hidden = embed_w[token_ids].copy() # [seq_len, n_embd]
1325
+ del embed_w # Free ~5 GB embedding table before layer loop
1326
 
1327
  # RoPE frequencies
1328
  cos_f, sin_f = rope_freqs(self.head_dim, seq_len, cfg['rope_base'])
 
1351
  if self.verbose and (layer_idx + 1) % 4 == 0:
1352
  print(f" Layer {layer_idx + 1}/{cfg['n_layers']}", end='\r')
1353
 
1354
+ # Output projection — check existence without loading the full 5 GB tensor
1355
+ if 'output.weight' in self.model.tensor_infos:
 
1356
  self._record('output.weight', hidden)
1357
 
1358
  return hidden