"""
Script to inspect a YOLO .pt model and determine its variant
(nano, small, medium, large, xlarge).
"""

import argparse
import re
import sys
from pathlib import Path

import torch

try:
    from ultralytics import YOLO
except ImportError:  # keep the raw-PyTorch fallback (Method 2) usable without ultralytics
    YOLO = None


# Human-readable labels for the single-letter model scales.
_SCALE_LABELS = {
    "n": "Nano (n)",
    "s": "Small (s)",
    "m": "Medium (m)",
    "l": "Large (l)",
    "x": "XLarge (x)",
}

# Scale letter immediately after the version number, e.g. "yolo11m", "yolov8n-seg",
# "yolov5su". Anchoring on the digits avoids false hits on other letters in the name
# (a plain "'s' in name" test would label "yolo11m-seg" as Small).
_SCALE_PATTERN = re.compile(r"yolo(?:e-)?v?\d+([nslmx])")

# (exclusive upper bound, label) pairs checked in order. Ranges overlap between
# YOLO versions (e.g. YOLO11l is ~25M params, inside the Medium band), so these
# are only a fallback when the YAML scale is unavailable.
_PARAM_BANDS = (
    (5_000_000, "Nano (n) - typically < 5M params"),
    (12_000_000, "Small (s) - typically 5-12M params"),
    (26_000_000, "Medium (m) - typically 12-26M params"),
    (44_000_000, "Large (l) - typically 26-44M params"),
)
_PARAM_FALLBACK = "XLarge (x) - typically > 44M params"

# Same idea for the on-disk size in MB. Very approximate: checkpoints that still
# carry optimizer state may be much larger than a stripped release file.
_FILE_SIZE_BANDS = (
    (10, "Likely Nano (n) - file < 10MB"),
    (22, "Likely Small (s) - file 10-22MB"),
    (50, "Likely Medium (m) - file 22-50MB"),
    (85, "Likely Large (l) - file 50-85MB"),
)
_FILE_SIZE_FALLBACK = "Likely XLarge (x) - file > 85MB"


def _yaml_stem(yaml_cfg) -> str:
    """Return the YAML file stem for a path/str or a config dict ("" if unknown)."""
    # NOTE(review): assumes Ultralytics config dicts may record their source file
    # under "yaml_file" — absent keys simply yield "".
    source = yaml_cfg.get("yaml_file", "") if isinstance(yaml_cfg, dict) else yaml_cfg
    return Path(str(source)).stem if source else ""


def _variant_from_yaml(yaml_cfg) -> str:
    """Return a variant label such as "Medium (m)" for a YAML config, else "Unknown".

    Args:
        yaml_cfg: either a path/str naming the model YAML, or the config dict
            stored on an Ultralytics model's ``.yaml`` attribute.
    """
    if isinstance(yaml_cfg, dict):
        # Presumably set by recent Ultralytics versions; "" when it could not guess.
        scale = str(yaml_cfg.get("scale") or "").lower()
        if scale in _SCALE_LABELS:
            return _SCALE_LABELS[scale]
    match = _SCALE_PATTERN.search(_yaml_stem(yaml_cfg).lower())
    return _SCALE_LABELS[match.group(1)] if match else "Unknown"


def _estimate_from_params(total_params: int) -> str:
    """Map a parameter count to a rough variant label."""
    for upper, label in _PARAM_BANDS:
        if total_params < upper:
            return label
    return _PARAM_FALLBACK


def _estimate_from_file_size(file_size_mb: float) -> str:
    """Map a checkpoint size in MB to a rough variant label."""
    for upper, label in _FILE_SIZE_BANDS:
        if file_size_mb < upper:
            return label
    return _FILE_SIZE_FALLBACK


def _inspect_with_ultralytics(model_path: Path) -> None:
    """Method 1: load through Ultralytics and print config, classes and parameter counts.

    Raises:
        ImportError: if ultralytics is not installed.
        Exception: whatever ``YOLO(...)`` raises for an unloadable file.
    """
    if YOLO is None:
        raise ImportError("ultralytics is not installed")
    model = YOLO(str(model_path))
    net = model.model

    print("\n--- Model Information ---")
    print(f"Model type: {type(net)}")

    if hasattr(net, "yaml"):
        yaml_cfg = net.yaml
        print(f"YAML config: {yaml_cfg}")
        if yaml_cfg:
            print(f"YAML name: {_yaml_stem(yaml_cfg) or '(not recorded)'}")
            print(f"Detected variant: {_variant_from_yaml(yaml_cfg)}")

    names = getattr(net, "names", None)
    if names is not None:
        # names may be a {index: label} dict or a plain sequence of labels.
        labels = list(names.values()) if isinstance(names, dict) else list(names)
        print(f"Number of classes: {len(labels)}")
        print(f"Class names: {labels[:5]}...")  # Show first 5

    print("\n--- Model Summary ---")
    try:
        print(model.info(verbose=False))
    except Exception as err:  # the summary is best-effort; keep inspecting
        print(f"Could not build model summary: {err}")

    if hasattr(net, "parameters"):
        params = list(net.parameters())
        total_params = sum(p.numel() for p in params)
        trainable_params = sum(p.numel() for p in params if p.requires_grad)
        print("\n--- Parameter Count ---")
        print(f"Total parameters: {total_params:,}")
        print(f"Trainable parameters: {trainable_params:,}")
        print(f"Size estimate: {_estimate_from_params(total_params)}")


def _load_checkpoint(model_path: Path):
    """``torch.load`` the file onto the CPU with full unpickling.

    SECURITY: ``weights_only=False`` unpickles arbitrary objects, which can run
    arbitrary code — only inspect checkpoints from trusted sources. Full
    unpickling is needed because Ultralytics checkpoints store whole
    ``nn.Module`` objects, which torch >= 2.6 rejects by default.
    """
    try:
        return torch.load(str(model_path), map_location="cpu", weights_only=False)
    except TypeError:
        # torch < 1.13 has no weights_only argument (it always unpickles fully).
        return torch.load(str(model_path), map_location="cpu")


def _inspect_with_torch(model_path: Path) -> None:
    """Method 2: unpickle the checkpoint directly and print state-dict hints and metadata."""
    checkpoint = _load_checkpoint(model_path)

    if isinstance(checkpoint, dict):
        # Ultralytics' own loader prefers "ema" and falls back to "model"; training
        # checkpoints keep the weights under "ema" with "model" set to None.
        net = checkpoint.get("ema")
        if net is None:
            net = checkpoint.get("model")
        meta = checkpoint
    else:
        net, meta = checkpoint, {}  # a bare module / object was saved

    # Accept either an nn.Module (take its state dict) or an already-plain state dict.
    state = net.state_dict() if hasattr(net, "state_dict") else net
    if isinstance(state, dict):
        print("Checking state dict keys for architecture hints...")
        for key in list(state)[:10]:  # First 10 keys
            print(f"  {key}")
        tensor_count = sum(
            1 for key in state if "weight" in str(key) or "bias" in str(key)
        )
        print(f"\nTotal weight/bias tensors: {tensor_count}")

    if "epoch" in meta:
        print(f"Training epoch: {meta['epoch']}")
    if "best_fitness" in meta:
        print(f"Best fitness: {meta['best_fitness']}")


def inspect_yolo_model(model_path: Path) -> None:
    """Inspect a YOLO checkpoint and print its likely variant and architecture details.

    Runs two independent methods — an Ultralytics load and a raw ``torch.load`` —
    then reports the file size. Failures of either method are printed, not raised.

    Args:
        model_path: path to the ``.pt`` file (a ``str`` is also accepted).
    """
    model_path = Path(model_path)
    print(f"Inspecting model: {model_path}")
    print("=" * 60)

    # Method 1: Load with Ultralytics and check metadata
    try:
        _inspect_with_ultralytics(model_path)
    except Exception as e:
        print(f"Error loading with Ultralytics: {e}")
        print("\nTrying alternative method...")

    # Method 2: Direct PyTorch inspection
    print("\n" + "=" * 60)
    print("--- Direct PyTorch Inspection ---")
    try:
        _inspect_with_torch(model_path)
    except Exception as e:
        print(f"Error with direct PyTorch inspection: {e}")

    # The file size needs no unpickling, so report it even if loading failed.
    try:
        file_size_mb = model_path.stat().st_size / (1024 * 1024)
    except OSError as e:
        print(f"Could not read model file size: {e}")
    else:
        print(f"\nModel file size: {file_size_mb:.2f} MB")
        print(f"Size estimate from file: {_estimate_from_file_size(file_size_mb)}")

    print("\n" + "=" * 60)
    print("Inspection complete!")


def main() -> int:
    """CLI entry point; returns a process exit status (0 ok, 1 missing file)."""
    parser = argparse.ArgumentParser(
        description="Inspect YOLO .pt model to determine variant"
    )
    parser.add_argument(
        "--model_path",
        type=Path,
        required=True,  # without it args.model_path is None and .exists() crashes
        help="Path to YOLO .pt model file",
    )
    args = parser.parse_args()

    if not args.model_path.is_file():
        print(f"Error: Model file not found: {args.model_path}", file=sys.stderr)
        return 1

    inspect_yolo_model(args.model_path)
    return 0


if __name__ == "__main__":
    sys.exit(main())