# model_tools/weight_counter.py
# Uploaded by Naphula ("Upload 5 files", commit 7080631, verified).
import math
import sys

import yaml
def main():
    """Sum the scalar ``weight`` parameters of every model in a merge config.

    Usage: ``python weight_counter.py <config.yaml>``

    Reads the YAML file named by the first command-line argument, walks its
    top-level ``models`` list, and prints how many entries carried a scalar
    (int/float) weight together with the sum of those weights. A model with
    no ``parameters`` or no ``weight`` (absent or left empty in the YAML)
    counts as weight 0. Non-scalar weights (e.g. lists used as gradients),
    booleans, and malformed entries are reported on stdout and skipped.

    Exits with status 1, writing the message to stderr, when no path is
    given, the file cannot be read or parsed, or the document is not a
    mapping / its ``models`` value is not a list.
    """
    if len(sys.argv) < 2:
        print("Usage: python weight_counter.py <config.yaml>", file=sys.stderr)
        sys.exit(1)
    config_path = sys.argv[1]

    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load never constructs arbitrary Python objects from the file.
            config = yaml.safe_load(f)
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error reading file: {e}", file=sys.stderr)
        sys.exit(1)
    except yaml.YAMLError as e:
        print(f"Error reading file: {e}", file=sys.stderr)
        sys.exit(1)

    # An empty file parses to None; a list/scalar document has no 'models' key.
    if not isinstance(config, dict):
        print(
            f"Error: {config_path} does not contain a YAML mapping at the top level",
            file=sys.stderr,
        )
        sys.exit(1)

    print(f"Scanning: {config_path}...")

    # `models:` left empty in YAML is None; treat it like a missing key.
    models = config.get('models') or []
    if not isinstance(models, list):
        print("Error: 'models' must be a list", file=sys.stderr)
        sys.exit(1)

    scalar_weights = []
    for m in models:
        if not isinstance(m, dict):
            print(f" [!] Skipped malformed model entry: {m!r}")
            continue
        # `parameters:` may be absent or explicitly null; both mean "no weight",
        # which defaults to 0 exactly as a missing `weight` key does.
        params = m.get('parameters') or {}
        weight = params.get('weight', 0) if isinstance(params, dict) else None
        # bool is an int subclass, but `weight: true` is not a numeric weight.
        if isinstance(weight, (int, float)) and not isinstance(weight, bool):
            scalar_weights.append(weight)
        else:
            print(f" [!] Skipped non-scalar weight for: {m.get('model')}")

    # fsum avoids accumulated rounding error, so weights that should total 1.0
    # are reported as 1.0 rather than 0.9999999999999999.
    total_weight = math.fsum(scalar_weights)
    count = len(scalar_weights)

    print("-" * 30)
    print(f"Models Counted: {count}")
    print(f"Total Weight Sum: {total_weight}")
    print("-" * 30)
# Run the CLI only when executed as a script, never on import.
if __name__ == "__main__":
    main()