| """
|
| Script to inspect a YOLO .pt model and determine its variant (nano, small, medium, large, xlarge).
|
| """
|
import argparse
import inspect
import re
import sys
from pathlib import Path

import torch
from ultralytics import YOLO
|
|
|
|
|
_VARIANT_LABELS = {
    "n": "Nano (n)",
    "s": "Small (s)",
    "m": "Medium (m)",
    "l": "Large (l)",
    "x": "XLarge (x)",
}

# Scale letter immediately after the version number in Ultralytics model names
# ("yolov8n", "yolo11s-seg", "yolov5mu"). A bare substring test is wrong: every
# name contains "l" (from "yolo"), and "-seg"/"-pose" contain "s".
_SCALE_PATTERN = re.compile(r"yolo(?:e-)?v?\d+([nslmx])")

# (exclusive upper bound, label) pairs checked in order. Rough hints only: the
# bounds bracket the YOLOv8 detect sizes (n ~3.2M, s ~11.2M, m ~25.9M,
# l ~43.7M, x ~68.2M params); other families and task heads shift them.
_PARAM_BUCKETS = (
    (5_000_000, "Nano (n) - typically < 5M params"),
    (12_000_000, "Small (s) - typically 5-12M params"),
    (26_000_000, "Medium (m) - typically 12-26M params"),
    (44_000_000, "Large (l) - typically 26-44M params"),
)
_PARAM_FALLBACK = "XLarge (x) - typically > 44M params"

# Same idea for on-disk size. Assumes a stripped release-style checkpoint
# (presumably FP16 weights, ~2 bytes/param, no optimizer state -- verify for
# your files); the old 6MB nano cutoff sat below the ~6MB yolov8n.pt. Training
# checkpoints that still carry optimizer state will read as a larger variant.
_FILE_MB_BUCKETS = (
    (10, "Likely Nano (n) - file < 10MB"),
    (22, "Likely Small (s) - file 10-22MB"),
    (50, "Likely Medium (m) - file 22-50MB"),
    (85, "Likely Large (l) - file 50-85MB"),
)
_FILE_MB_FALLBACK = "Likely XLarge (x) - file > 85MB"


def _bucket(value, buckets, fallback):
    """Return the label of the first bucket whose upper bound exceeds *value*."""
    for upper, label in buckets:
        if value < upper:
            return label
    return fallback


def _variant_from_config(cfg):
    """Return a variant label for a model config, or "Unknown".

    *cfg* is either the parsed config dict Ultralytics keeps on ``model.yaml``
    (its ``scale`` and ``yaml_file`` keys are consulted) or a path-like name.
    """
    if isinstance(cfg, dict):
        scale = str(cfg.get("scale") or "").lower()
        if scale in _VARIANT_LABELS:
            return _VARIANT_LABELS[scale]
        cfg = cfg.get("yaml_file") or ""
    match = _SCALE_PATTERN.search(Path(str(cfg)).stem.lower())
    return _VARIANT_LABELS[match.group(1)] if match else "Unknown"


def _report_config(net):
    """Print the model's config name and the variant inferred from it."""
    if not hasattr(net, "yaml"):
        return
    cfg = net.yaml
    # Show the source file rather than dumping the whole parsed config dict.
    if isinstance(cfg, dict):
        source = cfg.get("yaml_file") or "<in-memory config>"
    else:
        source = cfg
    print(f"YAML config: {source}")
    if not cfg:
        return
    print(f"YAML name: {Path(str(source)).stem}")
    print(f"Detected variant: {_variant_from_config(cfg)}")


def _report_classes(net):
    """Print the class count and the first few class names."""
    if not hasattr(net, "names"):
        return
    names = net.names
    # Ultralytics uses {index: name}; tolerate a plain sequence as well.
    labels = list(names.values()) if isinstance(names, dict) else list(names)
    print(f"Number of classes: {len(labels)}")
    more = "..." if len(labels) > 5 else ""
    print(f"Class names: {labels[:5]}{more}")


def _report_summary(model):
    """Print Ultralytics' own model summary, if it can produce one."""
    print("\n--- Model Summary ---")
    try:
        print(model.info(verbose=False))
    except Exception as e:  # optional extra; don't abort the remaining checks
        print(f"Model summary unavailable: {e}")


def _report_parameters(net):
    """Print parameter counts and a size estimate derived from them."""
    if not hasattr(net, "parameters"):
        return
    params = list(net.parameters())
    total_params = sum(p.numel() for p in params)
    trainable_params = sum(p.numel() for p in params if p.requires_grad)
    print("\n--- Parameter Count ---")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    size_estimate = _bucket(total_params, _PARAM_BUCKETS, _PARAM_FALLBACK)
    print(f"Size estimate: {size_estimate}")


def _inspect_with_ultralytics(model_path):
    """Load the model through Ultralytics and report config, classes and size."""
    try:
        model = YOLO(str(model_path))
        net = model.model
        print("\n--- Model Information ---")
        print(f"Model type: {type(net)}")
        _report_config(net)
        _report_classes(net)
        _report_summary(model)
        _report_parameters(net)
    except Exception as e:  # best-effort: fall through to the raw torch path
        print(f"Error loading with Ultralytics: {e}")
        print("\nTrying alternative method...")


def _load_checkpoint(model_path):
    """Load *model_path* with torch onto the CPU, allowing full-object pickles."""
    kwargs = {"map_location": "cpu"}
    # PyTorch 2.6 changed the ``weights_only`` default to True, which rejects
    # Ultralytics checkpoints (they pickle whole nn.Module objects). Opt out
    # explicitly where the keyword exists (torch >= 1.13).
    # SECURITY: this unpickles arbitrary objects -- only inspect trusted files.
    # (YOLO(...) above already loads the same file the same way.)
    if "weights_only" in inspect.signature(torch.load).parameters:
        kwargs["weights_only"] = False
    return torch.load(str(model_path), **kwargs)


def _state_dict_of(checkpoint):
    """Return the checkpoint's name->tensor mapping, or None if none is found."""
    weights = checkpoint.get("model") if isinstance(checkpoint, dict) else checkpoint
    # Ultralytics stores the module itself under "model", not a state dict.
    if isinstance(weights, torch.nn.Module):
        weights = weights.state_dict()
    return weights if isinstance(weights, dict) else None


def _report_checkpoint(checkpoint):
    """Print state-dict key hints and training metadata from a loaded checkpoint."""
    state = _state_dict_of(checkpoint)
    if state is not None:
        print("Checking state dict keys for architecture hints...")
        for name in list(state)[:10]:
            print(f"  {name}")
        tensor_count = sum(
            1 for name in state if "weight" in str(name) or "bias" in str(name)
        )
        print(f"\nTotal weight/bias tensors: {tensor_count}")

    if isinstance(checkpoint, dict):
        if "epoch" in checkpoint:
            print(f"Training epoch: {checkpoint['epoch']}")
        if "best_fitness" in checkpoint:
            print(f"Best fitness: {checkpoint['best_fitness']}")


def _inspect_with_torch(model_path):
    """Inspect the raw checkpoint with torch, then estimate size from the file."""
    print("\n" + "=" * 60)
    print("--- Direct PyTorch Inspection ---")
    try:
        _report_checkpoint(_load_checkpoint(model_path))
    except Exception as e:  # best-effort: still report the file size below
        print(f"Error with direct PyTorch inspection: {e}")

    # Kept outside the try above so a load failure doesn't hide this estimate.
    try:
        file_size_mb = model_path.stat().st_size / (1024 * 1024)
    except OSError as e:
        print(f"Error reading model file size: {e}")
        return
    print(f"\nModel file size: {file_size_mb:.2f} MB")
    size_estimate = _bucket(file_size_mb, _FILE_MB_BUCKETS, _FILE_MB_FALLBACK)
    print(f"Size estimate from file: {size_estimate}")


def inspect_yolo_model(model_path: Path) -> None:
    """Inspect a YOLO ``.pt`` file and print its likely variant and details.

    Two independent, best-effort passes are made. Each reports its own errors
    instead of raising:

    1. Load through Ultralytics and read the model config, class names and
       parameter count.
    2. Load the raw checkpoint with torch, list state-dict keys and training
       metadata, and estimate the variant from the file size.

    Args:
        model_path: Path to the ``.pt`` file. A ``str`` is accepted too.
    """
    model_path = Path(model_path)
    print(f"Inspecting model: {model_path}")
    print("=" * 60)

    _inspect_with_ultralytics(model_path)
    _inspect_with_torch(model_path)

    print("\n" + "=" * 60)
    print("Inspection complete!")
|
|
|
|
|
def main(argv=None):
    """Parse command-line arguments and inspect the requested model.

    Args:
        argv: Argument list to parse. ``None`` (the default) means argparse
            reads ``sys.argv[1:]``.

    Returns:
        Exit status: 0 after a successful inspection, 1 if the model file does
        not exist. argparse itself exits with status 2 on invalid arguments.
    """
    parser = argparse.ArgumentParser(
        description="Inspect YOLO .pt model to determine variant"
    )
    parser.add_argument(
        "--model_path",
        type=Path,
        # Without this, omitting the flag left None and crashed on .exists().
        required=True,
        help="Path to YOLO .pt model file",
    )
    args = parser.parse_args(argv)

    # is_file() rather than exists(): a directory is not a loadable model.
    if not args.model_path.is_file():
        print(f"Error: Model file not found: {args.model_path}", file=sys.stderr)
        return 1

    inspect_yolo_model(args.model_path)
    return 0
|
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status (None -> 0).
    raise SystemExit(main())
|
|
|
|
|