File size: 7,432 Bytes
4e78d70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import torch
import torch.nn as nn
from torchvision import models
import os

class AIDetectorModel(nn.Module):
    """RegNet-Y-16GF backbone topped with a custom MLP classification head.

    On construction the backbone is frozen, ``trunk_output.block4`` is made
    trainable again when the backbone exposes it, the pooling layer is
    replaced by adaptive max pooling, and ``fc`` is replaced by a
    Linear/SiLU/Dropout stack (2048 -> 1024 -> 512 -> ``num_classes``).

    Args:
        num_classes: Width of the final linear layer.
        dropout_prob: Dropout probability used after every hidden layer.
        load_pretrained: When true, build the backbone with the
            ``IMAGENET1K_SWAG_E2E_V1`` weights; otherwise with no weights.
    """

    def __init__(self, num_classes=2, dropout_prob=0.3, load_pretrained=True):
        super().__init__()

        # Pick the weight set first so the backbone is built in one place.
        if load_pretrained:
            print("πŸ“₯ Loading RegNet with pre-trained weights...")
            backbone_weights = models.RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_E2E_V1
        else:
            backbone_weights = None
        self.backbone = models.regnet_y_16gf(weights=backbone_weights)

        # Freeze everything, then re-enable gradients for the last trunk block.
        self.backbone.requires_grad_(False)
        trunk = getattr(self.backbone, 'trunk_output', None)
        if hasattr(trunk, 'block4'):
            trunk.block4.requires_grad_(True)

        # Width of the features feeding the original classifier.
        feature_dim = self.backbone.fc.in_features

        # Max pooling instead of the stock average pooling.
        self.backbone.avgpool = nn.AdaptiveMaxPool2d(output_size=(1, 1))

        # Build the head layer by layer; the resulting Sequential has the
        # same indices (0, 3, 6, 9 are the Linear layers) as a literal list.
        head_layers = []
        in_width = feature_dim
        for out_width in (2048, 1024, 512):
            head_layers.append(nn.Linear(in_width, out_width))
            head_layers.append(nn.SiLU())
            head_layers.append(nn.Dropout(dropout_prob))
            in_width = out_width
        head_layers.append(nn.Linear(in_width, num_classes))
        self.backbone.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run ``x`` through the backbone (including the replaced head)."""
        return self.backbone(x)

def verify_complete_backbone(model, *, min_block_keys=400, min_stem_keys=5,
                             min_proj_keys=20):
    """Check that a model's state dict looks like a complete RegNet backbone.

    The check is a key-count heuristic: it counts state-dict keys containing
    ``backbone.stem``, ``backbone.trunk_output.block`` and, among the latter,
    ``proj``, and requires each count to exceed its threshold. A summary is
    printed either way.

    Args:
        model: Any object whose ``state_dict()`` returns a mapping keyed by
            parameter/buffer name.
        min_block_keys: The number of trunk-block keys must be strictly
            greater than this.
        min_stem_keys: The number of stem keys must be strictly greater than
            this.
        min_proj_keys: The number of projection keys must be strictly greater
            than this.

    The defaults follow the original author's note for RegNet-Y-16GF
    (~420 block keys, ~24 projection keys, 6 stem keys); pass other values to
    verify a different variant.

    Returns:
        ``True`` when all three counts exceed their thresholds, else ``False``.
    """
    all_keys = list(model.state_dict())

    # Group keys by the backbone component they belong to.
    stem_keys = [k for k in all_keys if 'backbone.stem' in k]
    block_keys = [k for k in all_keys if 'backbone.trunk_output.block' in k]
    # Projection keys are a subset of the block keys, so filter those only.
    proj_keys = [k for k in block_keys if 'proj' in k]
    fc_keys = [k for k in all_keys if 'backbone.fc' in k]

    print("πŸ” Backbone verification:")
    print(f"   β€’ Total keys: {len(all_keys)}")
    print(f"   β€’ Stem keys: {len(stem_keys)}")
    print(f"   β€’ Block keys: {len(block_keys)}")
    print(f"   β€’ Projection keys: {len(proj_keys)}")
    print(f"   β€’ FC keys: {len(fc_keys)}")

    is_complete = (
        len(block_keys) > min_block_keys
        and len(stem_keys) > min_stem_keys
        and len(proj_keys) > min_proj_keys
    )

    if is_complete:
        print("βœ… Complete backbone verified!")
        print(f"βœ… Found {len(block_keys)} blocks, {len(stem_keys)} stem, {len(proj_keys)} proj layers")
    else:
        print("❌ Incomplete backbone detected!")
        # Report the thresholds actually applied, not fixed numbers.
        print(f"Expected: >{min_block_keys} blocks, >{min_stem_keys} stem, >{min_proj_keys} proj")
        print(f"Found: {len(block_keys)} blocks, {len(stem_keys)} stem, {len(proj_keys)} proj")

    return is_complete

def create_truly_complete_model(trained_checkpoint_path, output_path):
    """Merge a trained checkpoint with a pretrained RegNet and save the result.

    A fresh ``AIDetectorModel`` is built with pretrained backbone weights and
    every tensor found in the trained checkpoint is laid over it, so the
    pretrained values only fill in what the checkpoint does not provide. The
    merged state is loaded into a second model as a consistency check and
    then written to ``output_path`` together with the checkpoint metadata.

    Args:
        trained_checkpoint_path: File readable by ``torch.load``. Either a
            bare state dict, or a dict with a ``'model_state_dict'`` entry and
            optional ``num_classes``, ``dropout_prob``, ``epoch`` and
            ``val_loss`` entries.
        output_path: Where the merged checkpoint is saved.

    Returns:
        ``output_path``, unchanged.

    Raises:
        RuntimeError: If backbone verification fails for the fresh or final
            model, if the checkpoint lacks any classifier tensor, or if the
            merged state does not fit the architecture (raised by
            ``load_state_dict``).
    """

    print("πŸ”„ Creating TRULY complete offline model...")
    print("=" * 60)

    # Step 1: Create fresh model WITH pretrained weights
    print("πŸ“₯ Step 1: Loading fresh RegNet with ImageNet weights...")
    fresh_model = AIDetectorModel(load_pretrained=True)

    # Step 2: Verify fresh model is complete
    print("\nπŸ” Step 2: Verifying fresh model completeness...")
    if not verify_complete_backbone(fresh_model):
        raise RuntimeError("❌ Fresh model is not complete!")

    # Step 3: Load the trained weights
    print(f"\nπŸ“‚ Step 3: Loading your trained weights from {trained_checkpoint_path}...")
    # SECURITY: torch.load unpickles the file; on PyTorch versions where
    # weights_only is not the default this can execute arbitrary code. Only
    # load checkpoints from a trusted source (or pass weights_only=True when
    # the installed PyTorch supports it).
    trained_checkpoint = torch.load(trained_checkpoint_path, map_location='cpu')

    if isinstance(trained_checkpoint, dict) and 'model_state_dict' in trained_checkpoint:
        trained_state = trained_checkpoint['model_state_dict']
        extra_info = {
            'num_classes': trained_checkpoint.get('num_classes', 2),
            'dropout_prob': trained_checkpoint.get('dropout_prob', 0.3),
            'epoch': trained_checkpoint.get('epoch', 0),
            'val_loss': trained_checkpoint.get('val_loss', 0.0)
        }
    else:
        trained_state = trained_checkpoint
        extra_info = {'num_classes': 2, 'dropout_prob': 0.3}

    # Step 4: Merge weights
    print("\nπŸ”€ Step 4: Merging backbone + trained classifier...")
    fresh_state = fresh_model.state_dict()

    classifier_keys = [k for k in trained_state if 'backbone.fc' in k]
    backbone_keys = [k for k in trained_state if 'backbone.fc' not in k]

    # Without this guard a checkpoint whose keys do not line up (for example
    # a prefixed or differently shaped head) would leave the head at its
    # random initialisation and still be saved as "complete".
    missing_head = [
        k for k in fresh_state
        if 'backbone.fc' in k and k not in trained_state
    ]
    if missing_head:
        raise RuntimeError(
            "Trained checkpoint is missing classifier tensors: "
            + ", ".join(missing_head)
        )

    # Start from the pretrained state, then let the checkpoint win wherever
    # it has a value. AIDetectorModel leaves trunk_output.block4 trainable,
    # so a checkpoint trained with it can hold backbone tensors that differ
    # from the pretrained ones; copying only the classifier would discard
    # them and pair the trained head with a backbone it was not trained on.
    merged_state = fresh_state.copy()

    print(f"   β€’ Replacing {len(classifier_keys)} classifier parameters")
    for key in classifier_keys:
        merged_state[key] = trained_state[key]
        print(f"   βœ… Updated: {key}")

    for key in backbone_keys:
        merged_state[key] = trained_state[key]
    print(f"   - Restored {len(backbone_keys)} trained backbone tensors from the checkpoint")

    # Step 5: Create final model and verify
    print("\nβœ… Step 5: Creating final complete model...")
    # Build the empty architecture with the checkpoint's own head settings so
    # a checkpoint trained with a different class count still loads.
    final_model = AIDetectorModel(
        num_classes=extra_info['num_classes'],
        dropout_prob=extra_info['dropout_prob'],
        load_pretrained=False,
    )
    # Strict load: raises if any key is missing, unexpected or mis-shaped.
    final_model.load_state_dict(merged_state)

    print("\nπŸ” Final verification:")
    if not verify_complete_backbone(final_model):
        raise RuntimeError("❌ Final model verification failed!")

    # Step 6: Save complete model
    print(f"\nπŸ’Ύ Step 6: Saving complete model to {output_path}...")
    complete_checkpoint = {
        'model_state_dict': merged_state,
        'complete_model': True,
        'backbone_included': True,
        'offline_ready': True,
        **extra_info
    }

    torch.save(complete_checkpoint, output_path)

    file_size = os.path.getsize(output_path) / (1024*1024)
    print("βœ… SUCCESS! Complete model saved:")
    print(f"   β€’ File: {output_path}")
    print(f"   β€’ Size: {file_size:.1f}MB")
    print("   β€’ Ready for offline use!")

    return output_path

if __name__ == "__main__":
    # Trained checkpoint to read and merged checkpoint to write.
    input_model = "2_block_best_ai_detector.pth"
    output_model = "truly_complete_ai_detector_new.pth"

    banner = "=" * 60
    bytes_per_mb = 1024 * 1024

    if os.path.exists(input_model):
        try:
            result_path = create_truly_complete_model(input_model, output_model)
            print("\n" + banner)
            print("πŸŽ‰ SUCCESS! Use this file in your offline app:")
            print(f"πŸ“ {result_path}")
            print(banner)
        except Exception as e:
            # Report the failure as a message instead of a traceback.
            print(f"❌ Failed to create complete model: {e}")
    else:
        print(f"❌ Trained model not found: {input_model}")
        # List the checkpoints in the working directory to help pick one.
        candidates = [name for name in os.listdir('.') if name.endswith('.pth')]
        if not candidates:
            print("No .pth files found in current directory")
        else:
            print("Available .pth files:")
            for name in candidates:
                size_mb = os.path.getsize(name) / bytes_per_mb
                print(f"   β€’ {name} ({size_mb:.1f}MB)")