# ai_docker / saveFullModel.py
# Author: yatinece
# Commit 4e78d70: "Add application file"
import torch
import torch.nn as nn
from torchvision import models
import os
class AIDetectorModel(nn.Module):
    """RegNet-Y-16GF classifier with a mostly frozen trunk and a deep MLP head.

    Every backbone parameter is frozen except those of the last trunk stage
    (``trunk_output.block4``, when the backbone has one). The global average
    pool is swapped for a global max pool. The original ``fc`` layer is
    replaced by a Linear/SiLU/Dropout MLP with 2048 -> 1024 -> 512 hidden units,
    ending in ``num_classes`` outputs. The new head is created after freezing,
    so its parameters stay trainable.

    Args:
        num_classes: Size of the final output layer.
        dropout_prob: Dropout probability applied after each hidden layer.
        load_pretrained: If True, initialise the backbone with torchvision's
            ``RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_E2E_V1`` weights (torchvision
            may need to download them); otherwise build it with
            ``weights=None``.
    """

    def __init__(self, num_classes=2, dropout_prob=0.3, load_pretrained=True):
        super().__init__()

        weights = None
        if load_pretrained:
            print("πŸ“₯ Loading RegNet with pre-trained weights...")
            weights = models.RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_E2E_V1
        backbone = models.regnet_y_16gf(weights=weights)
        self.backbone = backbone

        # Freeze everything, then re-enable only the final trunk stage.
        backbone.requires_grad_(False)
        trunk = getattr(backbone, 'trunk_output', None)
        if hasattr(trunk, 'block4'):
            trunk.block4.requires_grad_(True)

        # Global max pooling instead of global average pooling.
        backbone.avgpool = nn.AdaptiveMaxPool2d(output_size=(1, 1))

        # Replacement head. The layer order is fixed: the Linear layers sit at
        # indices 0, 3, 6 and 9, which the ``backbone.fc.N.*`` state_dict keys
        # depend on.
        layers = []
        in_features = backbone.fc.in_features
        for out_features in (2048, 1024, 512):
            layers.extend((
                nn.Linear(in_features, out_features),
                nn.SiLU(),
                nn.Dropout(dropout_prob),
            ))
            in_features = out_features
        layers.append(nn.Linear(in_features, num_classes))
        backbone.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the modified backbone and return its output logits."""
        return self.backbone(x)
def verify_complete_backbone(model, *, block_threshold=400, stem_threshold=5,
                             proj_threshold=20):
    """Sanity-check that ``model``'s state dict contains a full RegNet backbone.

    This is a structural check only. It counts ``state_dict`` entries
    (parameters *and* buffers such as BatchNorm running statistics) whose names
    match the RegNet stem, the trunk blocks and the trunk projection
    shortcuts. Each count must exceed its threshold. Tensor values are never
    inspected, so a randomly initialised backbone passes exactly like a
    pretrained one.

    The defaults are tuned for RegNet-Y-16GF wrapped as ``backbone``. That
    model has roughly 420 block entries, 24 projection entries and 6 stem
    entries.

    Args:
        model: Any object exposing ``state_dict()`` whose keys carry the
            ``backbone.`` prefix.
        block_threshold: The block-entry count must be strictly greater.
        stem_threshold: The stem-entry count must be strictly greater.
        proj_threshold: The projection-entry count must be strictly greater.

    Returns:
        True if all three counts exceed their thresholds, otherwise False.
        A summary is printed either way.
    """
    all_keys = list(model.state_dict())

    stem_keys = [k for k in all_keys if 'backbone.stem' in k]
    block_keys = [k for k in all_keys if 'backbone.trunk_output.block' in k]
    # Projection (shortcut) entries are a subset of the block entries.
    proj_keys = [k for k in block_keys if 'proj' in k]
    fc_keys = [k for k in all_keys if 'backbone.fc' in k]

    print("πŸ” Backbone verification:")
    print(f" β€’ Total keys: {len(all_keys)}")
    print(f" β€’ Stem keys: {len(stem_keys)}")
    print(f" β€’ Block keys: {len(block_keys)}")
    print(f" β€’ Projection keys: {len(proj_keys)}")
    print(f" β€’ FC keys: {len(fc_keys)}")

    is_complete = (
        len(block_keys) > block_threshold
        and len(stem_keys) > stem_threshold
        and len(proj_keys) > proj_threshold
    )
    if is_complete:
        print("βœ… Complete backbone verified!")
        print(f"βœ… Found {len(block_keys)} blocks, {len(stem_keys)} stem, {len(proj_keys)} proj layers")
    else:
        print("❌ Incomplete backbone detected!")
        print(f"Expected: >{block_threshold} blocks, >{stem_threshold} stem, >{proj_threshold} proj")
        print(f"Found: {len(block_keys)} blocks, {len(stem_keys)} stem, {len(proj_keys)} proj")
    return is_complete
def _load_trained_checkpoint(trained_checkpoint_path):
    """Load a training checkpoint and return ``(state_dict, metadata)``.

    Accepts either a bare state dict or a dict holding ``'model_state_dict'``
    plus optional ``num_classes``/``dropout_prob``/``epoch``/``val_loss``.
    Missing metadata falls back to the defaults below.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(trained_checkpoint_path, map_location='cpu')
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        extra_info = {
            'num_classes': checkpoint.get('num_classes', 2),
            'dropout_prob': checkpoint.get('dropout_prob', 0.3),
            'epoch': checkpoint.get('epoch', 0),
            'val_loss': checkpoint.get('val_loss', 0.0),
        }
        return checkpoint['model_state_dict'], extra_info
    return checkpoint, {'num_classes': 2, 'dropout_prob': 0.3}


def create_truly_complete_model(trained_checkpoint_path, output_path, *,
                                classifier_only=False):
    """Save a self-contained checkpoint: pretrained RegNet base plus trained weights.

    A fresh ``AIDetectorModel`` is built with torchvision's pretrained RegNet
    weights. Its state dict is then overlaid with the entries from
    ``trained_checkpoint_path``. Every trained entry replaces the pretrained
    one, so fine-tuned backbone layers are kept (``AIDetectorModel`` leaves
    ``trunk_output.block4`` trainable), and so are the trained BatchNorm
    statistics. Any entry the checkpoint lacks is filled from the pretrained
    model.

    The merged weights are validated by a strict ``load_state_dict`` into a
    fresh architecture. They are then saved together with the checkpoint's
    metadata.

    Args:
        trained_checkpoint_path: Path to a bare state dict, or to a dict holding
            ``'model_state_dict'`` plus optional ``num_classes``,
            ``dropout_prob``, ``epoch`` and ``val_loss``.
        output_path: Destination file for the combined checkpoint.
        classifier_only: If True, take only ``backbone.fc`` entries from the
            trained checkpoint and keep the whole backbone at its ImageNet
            values (the previous behaviour).

    Returns:
        ``output_path``.

    Raises:
        RuntimeError: A completeness check fails, the checkpoint supplies no
            classifier (``backbone.fc``) weights, or the merged weights do not
            load strictly. Strict loading fails on unexpected keys (e.g. a
            ``module.`` prefix) or on shape mismatches.
    """
    print("πŸ”„ Creating TRULY complete offline model...")
    print("=" * 60)

    # Step 1: read the checkpoint first so that both models below use the head
    # size it was trained with. Building them with the default num_classes
    # would make a non-binary checkpoint fail to load.
    print(f"πŸ“‚ Step 1: Loading your trained weights from {trained_checkpoint_path}...")
    trained_state, extra_info = _load_trained_checkpoint(trained_checkpoint_path)
    model_kwargs = {
        'num_classes': extra_info['num_classes'],
        'dropout_prob': extra_info['dropout_prob'],
    }

    # Step 2: fresh model WITH pretrained backbone weights.
    print("\nπŸ“₯ Step 2: Loading fresh RegNet with ImageNet weights...")
    fresh_model = AIDetectorModel(load_pretrained=True, **model_kwargs)

    # Step 3: structural sanity check of the fresh backbone.
    print("\nπŸ” Step 3: Verifying fresh model completeness...")
    if not verify_complete_backbone(fresh_model):
        raise RuntimeError("❌ Fresh model is not complete!")

    # Step 4: pretrained weights as the base, trained weights on top.
    print("\nπŸ”€ Step 4: Merging backbone + trained classifier...")
    fresh_state = fresh_model.state_dict()
    if classifier_only:
        overrides = {k: v for k, v in trained_state.items() if 'backbone.fc' in k}
    else:
        overrides = dict(trained_state)

    classifier_keys = [k for k in overrides if 'backbone.fc' in k]
    if not classifier_keys:
        # Without a trained head the "complete" model would just be random.
        raise RuntimeError(
            "❌ Trained checkpoint contains no classifier ('backbone.fc') weights!")
    backbone_keys = [k for k in overrides
                     if k in fresh_state and 'backbone.fc' not in k]

    print(f" β€’ Replacing {len(classifier_keys)} classifier parameters")
    print(f" β€’ Replacing {len(backbone_keys)} backbone entries with trained values")
    for key in classifier_keys:
        print(f" βœ… Updated: {key}")

    # copy() keeps the fresh key order; overridden keys keep their position.
    merged_state = fresh_state.copy()
    merged_state.update(overrides)

    # Step 5: a strict load into an empty architecture rejects unexpected keys
    # and shape mismatches before anything is written to disk.
    print("\nβœ… Step 5: Creating final complete model...")
    final_model = AIDetectorModel(load_pretrained=False, **model_kwargs)
    final_model.load_state_dict(merged_state)

    print("\nπŸ” Final verification:")
    if not verify_complete_backbone(final_model):
        raise RuntimeError("❌ Final model verification failed!")

    # Step 6: save the merged weights together with the checkpoint metadata.
    print(f"\nπŸ’Ύ Step 6: Saving complete model to {output_path}...")
    complete_checkpoint = {
        'model_state_dict': merged_state,
        'complete_model': True,
        'backbone_included': True,
        'offline_ready': True,
        **extra_info,
    }
    torch.save(complete_checkpoint, output_path)

    file_size = os.path.getsize(output_path) / (1024 * 1024)
    print("βœ… SUCCESS! Complete model saved:")
    print(f" β€’ File: {output_path}")
    print(f" β€’ Size: {file_size:.1f}MB")
    print(" β€’ Ready for offline use!")
    return output_path
if __name__ == "__main__":
    # Script entry point. Exits with status 1 on any failure so that shell
    # scripts and Docker build steps notice instead of continuing silently.
    import traceback

    # Input and output paths, relative to the current working directory.
    input_model = "2_block_best_ai_detector.pth"  # Your trained model
    output_model = "truly_complete_ai_detector_new.pth"  # New complete model

    if not os.path.exists(input_model):
        print(f"❌ Trained model not found: {input_model}")
        # List candidate checkpoints (sorted for stable output) to help the user.
        pth_files = sorted(f for f in os.listdir('.') if f.endswith('.pth'))
        if pth_files:
            print("Available .pth files:")
            for f in pth_files:
                size = os.path.getsize(f) / (1024 * 1024)
                print(f" β€’ {f} ({size:.1f}MB)")
        else:
            print("No .pth files found in current directory")
        raise SystemExit(1)

    try:
        result_path = create_truly_complete_model(input_model, output_model)
    except Exception as e:
        # Top-level boundary: show the full traceback, then fail loudly.
        traceback.print_exc()
        print(f"❌ Failed to create complete model: {e}")
        raise SystemExit(1) from e

    print("\n" + "=" * 60)
    print("πŸŽ‰ SUCCESS! Use this file in your offline app:")
    print(f"πŸ“ {result_path}")
    print("=" * 60)