CharlesCNorton
committed on
Commit
·
9c3e925
1
Parent(s):
1721000
Add flexible safetensors file loading
Browse files
- Add --weights argument to specify custom .safetensors file
- Auto-discover any .safetensors file if model.safetensors not found
- Update discover_circuits to accept any .safetensors file
prune.py
CHANGED
|
@@ -495,9 +495,10 @@ class ComputationGraph:
|
|
| 495 |
class AdaptiveCircuit:
|
| 496 |
"""Adaptive threshold circuit with automatic batched evaluation."""
|
| 497 |
|
| 498 |
-
def __init__(self, path: Path, device: str = 'cuda'):
|
| 499 |
self.path = Path(path)
|
| 500 |
self.device = device
|
|
|
|
| 501 |
self.spec = self._load_spec()
|
| 502 |
self.weights = self._load_weights()
|
| 503 |
self.weight_keys = list(self.weights.keys())
|
|
@@ -655,7 +656,17 @@ class AdaptiveCircuit:
|
|
| 655 |
)
|
| 656 |
|
| 657 |
def _load_weights(self) -> Dict[str, torch.Tensor]:
|
| 658 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 659 |
return {k: v.float().to(self.device) for k, v in w.items()}
|
| 660 |
|
| 661 |
def _build_tests(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
@@ -3010,7 +3021,7 @@ def discover_circuits(base: Path = CIRCUITS_PATH) -> List[CircuitSpec]:
|
|
| 3010 |
"""Find all circuits."""
|
| 3011 |
circuits = []
|
| 3012 |
for d in base.iterdir():
|
| 3013 |
-
if d.is_dir() and (d / 'config.json').exists() and (d
|
| 3014 |
try:
|
| 3015 |
with open(d / 'config.json') as f:
|
| 3016 |
cfg = json.load(f)
|
|
@@ -3105,6 +3116,7 @@ def run_pipeline(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Con
|
|
| 3105 |
def main():
|
| 3106 |
parser = argparse.ArgumentParser(description='Threshold Circuit Pruner v4')
|
| 3107 |
parser.add_argument('circuit', nargs='?', help='Circuit name')
|
|
|
|
| 3108 |
parser.add_argument('--list', action='store_true')
|
| 3109 |
parser.add_argument('--all', action='store_true')
|
| 3110 |
parser.add_argument('--max-inputs', type=int, default=10)
|
|
@@ -3198,7 +3210,7 @@ def main():
|
|
| 3198 |
print(f"Circuit not found: {args.circuit}")
|
| 3199 |
return
|
| 3200 |
|
| 3201 |
-
circuit = AdaptiveCircuit(path, cfg.device)
|
| 3202 |
|
| 3203 |
if args.pipeline:
|
| 3204 |
stages = [s.strip() for s in args.pipeline.split(',')]
|
|
|
|
| 495 |
class AdaptiveCircuit:
|
| 496 |
"""Adaptive threshold circuit with automatic batched evaluation."""
|
| 497 |
|
| 498 |
+
def __init__(self, path: Path, device: str = 'cuda', weights_file: str = None):
|
| 499 |
self.path = Path(path)
|
| 500 |
self.device = device
|
| 501 |
+
self.weights_file = weights_file
|
| 502 |
self.spec = self._load_spec()
|
| 503 |
self.weights = self._load_weights()
|
| 504 |
self.weight_keys = list(self.weights.keys())
|
|
|
|
| 656 |
)
|
| 657 |
|
| 658 |
def _load_weights(self) -> Dict[str, torch.Tensor]:
    """Load the circuit's weights from a .safetensors file onto ``self.device``.

    File selection:

    * If ``self.weights_file`` is set, exactly that file (joined onto
      ``self.path``) is loaded. A missing file is reported as an error
      instead of silently loading some other checkpoint.
    * Otherwise ``model.safetensors`` is used when it exists, else the
      first ``*.safetensors`` file in ``self.path`` in sorted order.

    Returns:
        Mapping from tensor name to a float32 tensor moved to ``self.device``.

    Raises:
        FileNotFoundError: If the explicitly requested file does not exist,
            or if no ``.safetensors`` file can be found in ``self.path``.
    """
    if self.weights_file:
        safetensors_file = self.path / self.weights_file
        if not safetensors_file.exists():
            # The caller asked for a specific file; falling back to another
            # checkpoint here would load different weights without any notice.
            raise FileNotFoundError(f"Weights file not found: {safetensors_file}")
    else:
        safetensors_file = self.path / 'model.safetensors'
        if not safetensors_file.exists():
            # glob() yields entries in filesystem order; sort so the same
            # file is chosen on every run and every platform.
            candidates = sorted(self.path.glob('*.safetensors'))
            if not candidates:
                raise FileNotFoundError(f"No .safetensors file found in {self.path}")
            safetensors_file = candidates[0]

    w = load_file(str(safetensors_file))
    return {k: v.float().to(self.device) for k, v in w.items()}
|
| 671 |
|
| 672 |
def _build_tests(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
| 3021 |
"""Find all circuits."""
|
| 3022 |
circuits = []
|
| 3023 |
for d in base.iterdir():
|
| 3024 |
+
if d.is_dir() and (d / 'config.json').exists() and list(d.glob('*.safetensors')):
|
| 3025 |
try:
|
| 3026 |
with open(d / 'config.json') as f:
|
| 3027 |
cfg = json.load(f)
|
|
|
|
| 3116 |
def main():
|
| 3117 |
parser = argparse.ArgumentParser(description='Threshold Circuit Pruner v4')
|
| 3118 |
parser.add_argument('circuit', nargs='?', help='Circuit name')
|
| 3119 |
+
parser.add_argument('--weights', type=str, help='Specific .safetensors file to load (default: model.safetensors or first found)')
|
| 3120 |
parser.add_argument('--list', action='store_true')
|
| 3121 |
parser.add_argument('--all', action='store_true')
|
| 3122 |
parser.add_argument('--max-inputs', type=int, default=10)
|
|
|
|
| 3210 |
print(f"Circuit not found: {args.circuit}")
|
| 3211 |
return
|
| 3212 |
|
| 3213 |
+
circuit = AdaptiveCircuit(path, cfg.device, args.weights)
|
| 3214 |
|
| 3215 |
if args.pipeline:
|
| 3216 |
stages = [s.strip() for s in args.pipeline.split(',')]
|