File size: 6,147 Bytes
e4189f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""

Script to inspect a YOLO .pt model and determine its variant (nano, small, medium, large, xlarge).

"""
import argparse
import re
from pathlib import Path

import torch
from ultralytics import YOLO


# Labels for the scale letter Ultralytics appends to model names (yolo11n, yolov8s, ...).
_VARIANT_LABELS = {
    "n": "Nano (n)",
    "s": "Small (s)",
    "m": "Medium (m)",
    "l": "Large (l)",
    "x": "XLarge (x)",
}

# The scale letter is the character right after the version number, e.g.
# "yolo11x", "yolov8n-seg", "yolov5su". Anchoring on the version matters: a
# naive substring test sees the "l" in "yolo" and calls every x model "Large".
_SCALE_PATTERN = re.compile(r"yolov?\d+([nsmlx])")


def _variant_from_name(name: str) -> str:
    """Return the variant label encoded in a YOLO config/weights name, else "Unknown"."""
    match = _SCALE_PATTERN.search(name.lower())
    return _VARIANT_LABELS[match.group(1)] if match else "Unknown"


def _inspect_with_ultralytics(model_path: Path) -> None:
    """Load the model through Ultralytics and print config, class and parameter info.

    Raises whatever ``YOLO(...)`` or attribute access raises; the caller reports it.
    """
    model = YOLO(str(model_path))

    print("\n--- Model Information ---")
    print(f"Model type: {type(model.model)}")

    cfg = getattr(model.model, "yaml", None)
    if cfg:
        scale = None
        if isinstance(cfg, dict):
            # A parsed config dict: printing it would dump the whole architecture,
            # so show only its source file. The 'yaml_file'/'scale' keys are set by
            # recent Ultralytics versions and may be absent in older ones.
            cfg_source = cfg.get("yaml_file", "")
            scale = cfg.get("scale")
        else:
            cfg_source = cfg
        print(f"YAML config: {cfg_source or 'N/A'}")
        yaml_name = Path(str(cfg_source)).stem if cfg_source else ""
        print(f"YAML name: {yaml_name}")
        # Prefer the explicit scale letter; otherwise parse it out of the name.
        variant = _VARIANT_LABELS.get(str(scale).lower()) if scale else None
        print(f"Detected variant: {variant or _variant_from_name(yaml_name)}")

    names = getattr(model.model, "names", None)
    if names is not None:
        # names is usually {index: label}, but accept a plain sequence too.
        class_names = list(names.values()) if isinstance(names, dict) else list(names)
        suffix = "..." if len(class_names) > 5 else ""
        print(f"Number of classes: {len(class_names)}")
        print(f"Class names: {class_names[:5]}{suffix}")  # Show first 5

    print("\n--- Model Summary ---")
    try:
        print(model.info(verbose=False))
    except Exception as exc:  # the summary is a best-effort extra; keep going
        print(f"Model summary unavailable: {exc}")

    if hasattr(model.model, "parameters"):
        params = list(model.model.parameters())
        total_params = sum(p.numel() for p in params)
        trainable_params = sum(p.numel() for p in params if p.requires_grad)
        print("\n--- Parameter Count ---")
        print(f"Total parameters: {total_params:,}")
        print(f"Trainable parameters: {trainable_params:,}")

        # Rough estimates for YOLO variants (these vary by version but give a ballpark)
        if total_params < 3_000_000:
            size_estimate = "Nano (n) - typically < 3M params"
        elif total_params < 12_000_000:
            size_estimate = "Small (s) - typically 3-12M params"
        elif total_params < 26_000_000:
            size_estimate = "Medium (m) - typically 12-26M params"
        elif total_params < 44_000_000:
            size_estimate = "Large (l) - typically 26-44M params"
        else:
            size_estimate = "XLarge (x) - typically > 44M params"
        print(f"Size estimate: {size_estimate}")


def _load_checkpoint(model_path: Path):
    """Unpickle a .pt file onto the CPU and return whatever object it contains.

    Ultralytics checkpoints store whole ``nn.Module`` objects, which PyTorch >= 2.6
    refuses to load under its default ``weights_only=True``. Full unpickling can
    execute arbitrary code, so only run this script on files you trust.
    """
    try:
        return torch.load(str(model_path), map_location="cpu", weights_only=False)
    except TypeError:
        # torch < 1.13 does not accept the weights_only keyword.
        return torch.load(str(model_path), map_location="cpu")


def _inspect_with_torch(model_path: Path) -> None:
    """Print state-dict and training-metadata hints from the raw checkpoint."""
    checkpoint = _load_checkpoint(model_path)
    if not isinstance(checkpoint, dict):
        print(f"Checkpoint is a {type(checkpoint).__name__}, not a dict; no metadata to show.")
        return

    model_obj = checkpoint.get("model")
    # The 'model' entry may be a module (as Ultralytics saves it) or already a
    # state dict; normalize to a name -> tensor mapping.
    state_dict = model_obj.state_dict() if hasattr(model_obj, "state_dict") else model_obj
    if isinstance(state_dict, dict):
        print("Checking state dict keys for architecture hints...")
        for key in list(state_dict)[:10]:  # First 10 keys
            print(f"  {key}")
        layer_count = sum(1 for k in state_dict if "weight" in k or "bias" in k)
        print(f"\nTotal weight/bias tensors: {layer_count}")

    if "epoch" in checkpoint:
        print(f"Training epoch: {checkpoint['epoch']}")
    if "best_fitness" in checkpoint:
        print(f"Best fitness: {checkpoint['best_fitness']}")
    # Training checkpoints may record their arguments, including the base model
    # name (e.g. "yolo11n.pt"), which names the variant directly when present.
    train_args = checkpoint.get("train_args")
    if isinstance(train_args, dict) and train_args.get("model"):
        print(f"Trained from: {train_args['model']}")


def _report_file_size(model_path: Path) -> None:
    """Print the file size and a very approximate variant guess derived from it."""
    file_size_mb = model_path.stat().st_size / (1024 * 1024)
    print(f"\nModel file size: {file_size_mb:.2f} MB")

    # Rough size estimates based on file size (very approximate)
    if file_size_mb < 6:
        size_estimate = "Likely Nano (n) - file < 6MB"
    elif file_size_mb < 22:
        size_estimate = "Likely Small (s) - file 6-22MB"
    elif file_size_mb < 50:
        size_estimate = "Likely Medium (m) - file 22-50MB"
    elif file_size_mb < 85:
        size_estimate = "Likely Large (l) - file 50-85MB"
    else:
        size_estimate = "Likely XLarge (x) - file > 85MB"
    print(f"Size estimate from file: {size_estimate}")


def inspect_yolo_model(model_path: Path):
    """Inspect a YOLO model file and print variant and architecture details.

    Tries Ultralytics first, then raw ``torch.load`` inspection, then a file-size
    heuristic. Each stage reports its own errors and never stops the later ones.

    Args:
        model_path: Path to the ``.pt`` file to inspect.
    """
    print(f"Inspecting model: {model_path}")
    print("=" * 60)

    # Method 1: Load with Ultralytics and check metadata
    try:
        _inspect_with_ultralytics(model_path)
    except Exception as e:
        print(f"Error loading with Ultralytics: {e}")
        print("\nTrying alternative method...")

    # Method 2: Direct PyTorch inspection
    print("\n" + "=" * 60)
    print("--- Direct PyTorch Inspection ---")
    try:
        _inspect_with_torch(model_path)
    except Exception as e:
        print(f"Error with direct PyTorch inspection: {e}")

    # The file-size hint needs no unpickling, so report it even if loading failed.
    try:
        _report_file_size(model_path)
    except OSError as e:
        print(f"Could not read model file size: {e}")

    print("\n" + "=" * 60)
    print("Inspection complete!")


def main():
    """Parse ``--model_path`` from the command line and inspect that model.

    Exits with status 2 (via argparse) when ``--model_path`` is omitted, and with
    status 1 when the path does not point to an existing file.
    """
    parser = argparse.ArgumentParser(
        description="Inspect YOLO .pt model to determine variant"
    )
    parser.add_argument(
        "--model_path",
        type=Path,
        required=True,  # otherwise a missing flag yields None and crashes below
        help="Path to YOLO .pt model file",
    )
    args = parser.parse_args()

    # is_file() also rejects directories, which exists() would let through.
    if not args.model_path.is_file():
        print(f"Error: Model file not found: {args.model_path}")
        raise SystemExit(1)

    inspect_yolo_model(args.model_path)


if __name__ == "__main__":
    main()