# visiontest / inspect_yolo_model.py
# Uploaded by tarto2 via huggingface_hub (commit e4189f9).
"""
Script to inspect a YOLO .pt model and determine its variant (nano, small, medium, large, xlarge).
"""
import argparse
import re
from pathlib import Path

import torch
from ultralytics import YOLO
_VARIANT_LABELS = {
    "n": "Nano (n)",
    "s": "Small (s)",
    "m": "Medium (m)",
    "l": "Large (l)",
    "x": "XLarge (x)",
}

# In YOLO config/weight names the scale letter sits directly after the version
# number: "yolov8n", "yolo11s-seg", "yolov5x6". Anchoring on that position keeps
# the "l" in "yolo" or the "s" in "-seg"/"-pose" from being mistaken for a scale.
_SCALE_PATTERN = re.compile(r"yolo(?:e-)?v?\d+([nslmx])", re.IGNORECASE)


def _variant_from_name(name):
    """Return a variant label parsed from a config or weight name, else "Unknown"."""
    match = _SCALE_PATTERN.search(str(name))
    if match is None:
        return "Unknown"
    return _VARIANT_LABELS[match.group(1).lower()]


def _variant_from_yaml(yaml_cfg):
    """Return ``(display_name, variant_label)`` for a model's ``yaml`` attribute.

    The attribute may be a path-like value or something else (the original
    code already anticipated non-path values). For a dict, an explicit
    ``"scale"`` entry is preferred and ``"yaml_file"`` supplies the name.
    NOTE(review): those key names follow current Ultralytics config dicts;
    confirm against the installed version.
    """
    if isinstance(yaml_cfg, dict):
        source = yaml_cfg.get("yaml_file") or ""
        name = Path(source).stem if source else "<in-memory config>"
        scale = str(yaml_cfg.get("scale") or "").lower()
        if scale in _VARIANT_LABELS:
            return name, _VARIANT_LABELS[scale]
        return name, _variant_from_name(name)
    name = Path(yaml_cfg).stem if isinstance(yaml_cfg, (str, Path)) else str(yaml_cfg)
    return name, _variant_from_name(name)


def _class_names(names):
    """Return class names as a list whether stored as a dict (id -> name) or a sequence."""
    return list(names.values()) if isinstance(names, dict) else list(names)


def _estimate_from_params(total_params):
    """Map a parameter count to a rough variant guess (ballpark; varies by YOLO version)."""
    if total_params < 3_000_000:
        return "Nano (n) - typically < 3M params"
    if total_params < 12_000_000:
        return "Small (s) - typically 3-12M params"
    if total_params < 26_000_000:
        return "Medium (m) - typically 12-26M params"
    if total_params < 44_000_000:
        return "Large (l) - typically 26-44M params"
    return "XLarge (x) - typically > 44M params"


def _estimate_from_file_size(file_size_mb):
    """Map a checkpoint file size in MB to a (very approximate) variant guess."""
    if file_size_mb < 6:
        return "Likely Nano (n) - file < 6MB"
    if file_size_mb < 22:
        return "Likely Small (s) - file 6-22MB"
    if file_size_mb < 50:
        return "Likely Medium (m) - file 22-50MB"
    if file_size_mb < 85:
        return "Likely Large (l) - file 50-85MB"
    return "Likely XLarge (x) - file > 85MB"


def _load_checkpoint(model_path):
    """Load a checkpoint onto the CPU, permitting full pickled objects.

    SECURITY: ``weights_only=False`` unpickles arbitrary objects, which can run
    code. Only inspect trusted files. The Ultralytics ``YOLO`` loader used in
    the first pass deserializes the same file anyway. With ``weights_only=True``
    (the default in newer torch releases), checkpoints that store a whole model
    object fail to load.
    """
    try:
        return torch.load(str(model_path), map_location="cpu", weights_only=False)
    except TypeError:
        # Older torch releases do not accept the ``weights_only`` keyword.
        return torch.load(str(model_path), map_location="cpu")


def _state_dict_of(model_entry):
    """Return a key -> tensor mapping for a checkpoint's "model" entry, or None.

    Accepts a plain dict, or any object exposing ``state_dict()`` (e.g. a
    pickled ``nn.Module``). Previously only the dict form was inspected.
    """
    if isinstance(model_entry, dict):
        return model_entry
    if hasattr(model_entry, "state_dict"):
        return model_entry.state_dict()
    return None


def _inspect_with_ultralytics(model_path):
    """Method 1: load through Ultralytics and print metadata-derived hints."""
    try:
        model = YOLO(str(model_path))
        net = model.model
        print("\n--- Model Information ---")
        print(f"Model type: {type(net)}")

        if hasattr(net, "yaml"):
            yaml_cfg = net.yaml
            print(f"YAML config: {yaml_cfg}")
            if yaml_cfg:
                yaml_name, variant = _variant_from_yaml(yaml_cfg)
                print(f"YAML name: {yaml_name}")
                if variant == "Unknown":
                    # Fall back to the weight file name, e.g. "yolo11n.pt".
                    from_file = _variant_from_name(model_path.stem)
                    if from_file != "Unknown":
                        variant = f"{from_file} (guessed from file name)"
                print(f"Detected variant: {variant}")

        if hasattr(net, "names"):
            names = _class_names(net.names)
            suffix = "..." if len(names) > 5 else ""
            print(f"Number of classes: {len(names)}")
            print(f"Class names: {names[:5]}{suffix}")  # Show at most the first 5

        print("\n--- Model Summary ---")
        try:
            print(model.info(verbose=False))
        except Exception as exc:  # The summary is best-effort; keep inspecting.
            print(f"Could not compute model summary: {exc}")

        if hasattr(net, "parameters"):
            params = list(net.parameters())
            total_params = sum(p.numel() for p in params)
            trainable_params = sum(p.numel() for p in params if p.requires_grad)
            print("\n--- Parameter Count ---")
            print(f"Total parameters: {total_params:,}")
            print(f"Trainable parameters: {trainable_params:,}")
            print(f"Size estimate: {_estimate_from_params(total_params)}")
    except Exception as e:
        print(f"Error loading with Ultralytics: {e}")
        print("\nTrying alternative method...")


def _inspect_with_torch(model_path):
    """Method 2: read the raw checkpoint and the file size without Ultralytics."""
    try:
        checkpoint = _load_checkpoint(model_path)
        if isinstance(checkpoint, dict):
            state_dict = _state_dict_of(checkpoint.get("model"))
            if state_dict is not None:
                print("Checking state dict keys for architecture hints...")
                for key in list(state_dict)[:10]:  # First 10 keys
                    print(f"  {key}")
                layer_count = sum(
                    1 for k in state_dict if "weight" in str(k) or "bias" in str(k)
                )
                print(f"\nTotal weight/bias tensors: {layer_count}")
            if "epoch" in checkpoint:
                print(f"Training epoch: {checkpoint['epoch']}")
            if "best_fitness" in checkpoint:
                print(f"Best fitness: {checkpoint['best_fitness']}")
        else:
            print(
                f"Checkpoint is a {type(checkpoint).__name__}, not a dict; "
                "no checkpoint metadata to read."
            )
    except Exception as e:
        print(f"Error with direct PyTorch inspection: {e}")

    # The file size does not depend on whether the checkpoint could be unpickled,
    # so it is reported even when loading failed above.
    try:
        file_size_mb = model_path.stat().st_size / (1024 * 1024)
    except OSError as e:
        print(f"\nCould not read model file size: {e}")
        return
    print(f"\nModel file size: {file_size_mb:.2f} MB")
    print(f"Size estimate from file: {_estimate_from_file_size(file_size_mb)}")


def inspect_yolo_model(model_path: Path):
    """Inspect a YOLO ``.pt`` model and print hints about its variant and architecture.

    Two independent passes run, and each prints its own section:

    1. Load through Ultralytics ``YOLO`` and report the config name and detected
       variant, the class names, the model summary and the parameter count.
    2. Load the raw checkpoint with ``torch.load`` and report state-dict keys,
       training metadata and a file-size based estimate.

    A failure in one pass is printed and does not stop the other.

    Args:
        model_path: Path to the ``.pt`` file (a ``str`` is also accepted).
    """
    model_path = Path(model_path)
    print(f"Inspecting model: {model_path}")
    print("=" * 60)
    _inspect_with_ultralytics(model_path)

    print("\n" + "=" * 60)
    print("--- Direct PyTorch Inspection ---")
    _inspect_with_torch(model_path)

    print("\n" + "=" * 60)
    print("Inspection complete!")
def main(argv=None):
    """Parse command-line arguments and run ``inspect_yolo_model`` on the given file.

    Args:
        argv: Optional argument list. ``None`` means argparse reads
            ``sys.argv[1:]``. Passing a list drives the CLI programmatically.

    Returns:
        ``None`` on success, or ``1`` when the model path is not an existing file.
    """
    parser = argparse.ArgumentParser(
        description="Inspect YOLO .pt model to determine variant"
    )
    parser.add_argument(
        "--model_path",
        type=Path,
        # Without required=True, omitting the flag left model_path as None and
        # crashed with AttributeError on None.exists().
        required=True,
        help="Path to YOLO .pt model file",
    )
    args = parser.parse_args(argv)
    # is_file() also rejects directories, which exists() would have accepted.
    if not args.model_path.is_file():
        print(f"Error: Model file not found: {args.model_path}")
        return 1
    inspect_yolo_model(args.model_path)
    return None


if __name__ == "__main__":
    # Propagate main()'s status so a missing model file yields a non-zero exit code.
    raise SystemExit(main())