CharlesCNorton committed on
Commit
9c3e925
·
1 Parent(s): 1721000

Add flexible safetensors file loading

Browse files

- Add --weights argument to specify custom .safetensors file
- Auto-discover any .safetensors file if model.safetensors not found
- Update discover_circuits to accept any .safetensors file

Files changed (1) hide show
  1. prune.py +16 -4
prune.py CHANGED
@@ -495,9 +495,10 @@ class ComputationGraph:
495
  class AdaptiveCircuit:
496
  """Adaptive threshold circuit with automatic batched evaluation."""
497
 
498
- def __init__(self, path: Path, device: str = 'cuda'):
499
  self.path = Path(path)
500
  self.device = device
 
501
  self.spec = self._load_spec()
502
  self.weights = self._load_weights()
503
  self.weight_keys = list(self.weights.keys())
@@ -655,7 +656,17 @@ class AdaptiveCircuit:
655
  )
656
 
657
  def _load_weights(self) -> Dict[str, torch.Tensor]:
658
- w = load_file(str(self.path / 'model.safetensors'))
 
 
 
 
 
 
 
 
 
 
659
  return {k: v.float().to(self.device) for k, v in w.items()}
660
 
661
  def _build_tests(self) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -3010,7 +3021,7 @@ def discover_circuits(base: Path = CIRCUITS_PATH) -> List[CircuitSpec]:
3010
  """Find all circuits."""
3011
  circuits = []
3012
  for d in base.iterdir():
3013
- if d.is_dir() and (d / 'config.json').exists() and (d / 'model.safetensors').exists():
3014
  try:
3015
  with open(d / 'config.json') as f:
3016
  cfg = json.load(f)
@@ -3105,6 +3116,7 @@ def run_pipeline(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Con
3105
  def main():
3106
  parser = argparse.ArgumentParser(description='Threshold Circuit Pruner v4')
3107
  parser.add_argument('circuit', nargs='?', help='Circuit name')
 
3108
  parser.add_argument('--list', action='store_true')
3109
  parser.add_argument('--all', action='store_true')
3110
  parser.add_argument('--max-inputs', type=int, default=10)
@@ -3198,7 +3210,7 @@ def main():
3198
  print(f"Circuit not found: {args.circuit}")
3199
  return
3200
 
3201
- circuit = AdaptiveCircuit(path, cfg.device)
3202
 
3203
  if args.pipeline:
3204
  stages = [s.strip() for s in args.pipeline.split(',')]
 
495
  class AdaptiveCircuit:
496
  """Adaptive threshold circuit with automatic batched evaluation."""
497
 
498
+ def __init__(self, path: Path, device: str = 'cuda', weights_file: str = None):
499
  self.path = Path(path)
500
  self.device = device
501
+ self.weights_file = weights_file
502
  self.spec = self._load_spec()
503
  self.weights = self._load_weights()
504
  self.weight_keys = list(self.weights.keys())
 
656
  )
657
 
658
  def _load_weights(self) -> Dict[str, torch.Tensor]:
659
+ if self.weights_file:
660
+ safetensors_file = self.path / self.weights_file
661
+ else:
662
+ safetensors_file = self.path / 'model.safetensors'
663
+ if not safetensors_file.exists():
664
+ candidates = list(self.path.glob('*.safetensors'))
665
+ if candidates:
666
+ safetensors_file = candidates[0]
667
+ else:
668
+ raise FileNotFoundError(f"No .safetensors file found in {self.path}")
669
+ w = load_file(str(safetensors_file))
670
  return {k: v.float().to(self.device) for k, v in w.items()}
671
 
672
  def _build_tests(self) -> Tuple[torch.Tensor, torch.Tensor]:
 
3021
  """Find all circuits."""
3022
  circuits = []
3023
  for d in base.iterdir():
3024
+ if d.is_dir() and (d / 'config.json').exists() and list(d.glob('*.safetensors')):
3025
  try:
3026
  with open(d / 'config.json') as f:
3027
  cfg = json.load(f)
 
3116
  def main():
3117
  parser = argparse.ArgumentParser(description='Threshold Circuit Pruner v4')
3118
  parser.add_argument('circuit', nargs='?', help='Circuit name')
3119
+ parser.add_argument('--weights', type=str, help='Specific .safetensors file to load (default: model.safetensors or first found)')
3120
  parser.add_argument('--list', action='store_true')
3121
  parser.add_argument('--all', action='store_true')
3122
  parser.add_argument('--max-inputs', type=int, default=10)
 
3210
  print(f"Circuit not found: {args.circuit}")
3211
  return
3212
 
3213
+ circuit = AdaptiveCircuit(path, cfg.device, args.weights)
3214
 
3215
  if args.pipeline:
3216
  stages = [s.strip() for s in args.pipeline.split(',')]