"""
Script to inspect a YOLO .pt model and determine its variant (nano, small, medium, large, xlarge).
"""
import argparse
from pathlib import Path
import torch
from ultralytics import YOLO
# Human-readable names for the single-letter YOLO variant suffixes
# (yolo11n, yolov8s, ...).
_VARIANT_NAMES = {
    "n": "Nano (n)",
    "s": "Small (s)",
    "m": "Medium (m)",
    "l": "Large (l)",
    "x": "XLarge (x)",
}


def _variant_from_yaml_name(yaml_name: str) -> str:
    """Return the human-readable variant for a YAML config stem.

    Common patterns: yolo11n, yolo11s, yolo11m, yolo11l, yolo11x,
    or yolov8n, yolov8s, etc. The variant letter is the *trailing*
    character of the stem. Checking the suffix (rather than substring
    membership) matters: every "yolo*" name contains the letter 'l',
    so a substring test would misclassify e.g. "yolo11x" as Large.
    """
    name = yaml_name.lower().strip()
    if not name:
        return "Unknown"
    return _VARIANT_NAMES.get(name[-1], "Unknown")


def _size_estimate_from_params(total_params: int) -> str:
    """Map a total parameter count to a rough YOLO variant estimate.

    Thresholds vary by YOLO version but give a ballpark.
    """
    if total_params < 3_000_000:
        return "Nano (n) - typically < 3M params"
    if total_params < 12_000_000:
        return "Small (s) - typically 3-12M params"
    if total_params < 26_000_000:
        return "Medium (m) - typically 12-26M params"
    if total_params < 44_000_000:
        return "Large (l) - typically 26-44M params"
    return "XLarge (x) - typically > 44M params"


def _size_estimate_from_file(file_size_mb: float) -> str:
    """Map a .pt file size in MB to a rough YOLO variant estimate (very approximate)."""
    if file_size_mb < 6:
        return "Likely Nano (n) - file < 6MB"
    if file_size_mb < 22:
        return "Likely Small (s) - file 6-22MB"
    if file_size_mb < 50:
        return "Likely Medium (m) - file 22-50MB"
    if file_size_mb < 85:
        return "Likely Large (l) - file 50-85MB"
    return "Likely XLarge (x) - file > 85MB"


def _inspect_with_ultralytics(model_path: Path) -> None:
    """Method 1: load the model with Ultralytics and report metadata.

    Prints model type, YAML config name, detected variant, class names,
    the model summary, and a parameter-count-based size estimate.
    Failures are reported but not raised.
    """
    try:
        model = YOLO(str(model_path))

        print("\n--- Model Information ---")
        print(f"Model type: {type(model.model)}")

        # Try to get the variant from the YAML config metadata.
        if hasattr(model, 'model') and hasattr(model.model, 'yaml'):
            yaml_path = model.model.yaml
            print(f"YAML config: {yaml_path}")
            if yaml_path:
                yaml_name = (
                    Path(yaml_path).stem
                    if isinstance(yaml_path, (str, Path))
                    else str(yaml_path)
                )
                print(f"YAML name: {yaml_name}")
                print(f"Detected variant: {_variant_from_yaml_name(yaml_name)}")

        # Class metadata, if the underlying model exposes it.
        if hasattr(model.model, 'names'):
            print(f"Number of classes: {len(model.model.names)}")
            print(f"Class names: {list(model.model.names.values())[:5]}...")  # Show first 5

        print("\n--- Model Summary ---")
        try:
            info = model.info(verbose=False)
            print(info)
        except Exception:
            # model.info() is best-effort; a failure here should not abort inspection.
            pass

        # Parameter counts give a version-independent size signal.
        if hasattr(model.model, 'parameters'):
            total_params = sum(p.numel() for p in model.model.parameters())
            trainable_params = sum(
                p.numel() for p in model.model.parameters() if p.requires_grad
            )
            print("\n--- Parameter Count ---")
            print(f"Total parameters: {total_params:,}")
            print(f"Trainable parameters: {trainable_params:,}")
            print(f"Size estimate: {_size_estimate_from_params(total_params)}")
    except Exception as e:
        print(f"Error loading with Ultralytics: {e}")
        print("\nTrying alternative method...")


def _inspect_with_torch(model_path: Path) -> None:
    """Method 2: inspect the raw checkpoint with torch.load.

    Prints a sample of state-dict keys, checkpoint metadata (epoch,
    best fitness), and a file-size-based variant estimate. Failures
    are reported but not raised.
    """
    print("\n" + "=" * 60)
    print("--- Direct PyTorch Inspection ---")
    try:
        # NOTE(security): torch.load unpickles arbitrary objects — only
        # inspect checkpoints from trusted sources.
        checkpoint = torch.load(str(model_path), map_location='cpu')

        if 'model' in checkpoint:
            model_dict = checkpoint['model']
            if isinstance(model_dict, dict):
                # Look for architecture hints in state dict keys.
                print("Checking state dict keys for architecture hints...")
                for key in list(model_dict.keys())[:10]:  # First 10 keys
                    print(f"  {key}")
                layer_count = len(
                    [k for k in model_dict if 'weight' in k or 'bias' in k]
                )
                print(f"\nTotal weight/bias tensors: {layer_count}")

        # Training metadata, if present in the checkpoint.
        if 'epoch' in checkpoint:
            print(f"Training epoch: {checkpoint.get('epoch', 'N/A')}")
        if 'best_fitness' in checkpoint:
            print(f"Best fitness: {checkpoint.get('best_fitness', 'N/A')}")

        # File size is a last-resort, very approximate variant signal.
        file_size_mb = model_path.stat().st_size / (1024 * 1024)
        print(f"\nModel file size: {file_size_mb:.2f} MB")
        print(f"Size estimate from file: {_size_estimate_from_file(file_size_mb)}")
    except Exception as e:
        print(f"Error with direct PyTorch inspection: {e}")


def inspect_yolo_model(model_path: Path) -> None:
    """Inspect a YOLO .pt model and report its variant and architecture details.

    Two methods are attempted in order, and results are printed to stdout:
      1. Load via Ultralytics and read YAML config / parameter counts.
      2. Directly unpickle the checkpoint with torch.load.

    :param model_path: path to the .pt checkpoint file (assumed to exist).
    """
    print(f"Inspecting model: {model_path}")
    print("=" * 60)

    _inspect_with_ultralytics(model_path)
    _inspect_with_torch(model_path)

    print("\n" + "=" * 60)
    print("Inspection complete!")
def main() -> None:
    """CLI entry point: parse arguments and run the model inspection."""
    parser = argparse.ArgumentParser(
        description="Inspect YOLO .pt model to determine variant"
    )
    parser.add_argument(
        "--model_path",
        type=Path,
        # required=True: without it a missing flag left model_path as None,
        # and args.model_path.exists() crashed with AttributeError instead
        # of a clean argparse usage error.
        required=True,
        help="Path to YOLO .pt model file"
    )
    args = parser.parse_args()

    if not args.model_path.exists():
        print(f"Error: Model file not found: {args.model_path}")
        return

    inspect_yolo_model(args.model_path)


if __name__ == "__main__":
    main()
|