"""Build a self-contained ("offline") AIDetector checkpoint.

Loads a trained ``AIDetectorModel`` checkpoint, fills in any weights it lacks
from torchvision's ImageNet-pretrained RegNet-Y-16GF, verifies the result and
saves a checkpoint meant to be loaded into ``AIDetectorModel(load_pretrained=False)``.
"""
# NOTE(review): the emoji prefixes in the status messages look mis-encoded
# (UTF-8 bytes decoded with a Windows code page). They are kept byte-for-byte
# because the intended glyphs cannot be recovered reliably; re-save them as
# UTF-8 if the console output matters.
import argparse
import os
import sys
import traceback

import torch
import torch.nn as nn
from torchvision import models

DEFAULT_NUM_CLASSES = 2
DEFAULT_DROPOUT_PROB = 0.3
# Substring identifying the classifier-head keys in the model's state dict.
CLASSIFIER_PREFIX = 'backbone.fc'


class AIDetectorModel(nn.Module):
    """RegNet-Y-16GF backbone with a four-layer MLP classification head.

    Every backbone parameter is frozen except ``trunk_output.block4`` (when
    that submodule exists); the replacement ``fc`` head is trainable. Global
    average pooling is replaced with global max pooling.

    Args:
        num_classes: Output size of the final linear layer.
        dropout_prob: Dropout probability used between head layers.
        load_pretrained: If True, initialise the backbone with torchvision's
            ``IMAGENET1K_SWAG_E2E_V1`` weights (torchvision may need to
            download them); otherwise the backbone is randomly initialised.
    """

    def __init__(self, num_classes=DEFAULT_NUM_CLASSES, dropout_prob=DEFAULT_DROPOUT_PROB,
                 load_pretrained=True):
        super().__init__()
        if load_pretrained:
            print("šŸ“„ Loading RegNet with pre-trained weights...")
            weights = models.RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_E2E_V1
        else:
            weights = None
        self.backbone = models.regnet_y_16gf(weights=weights)

        # Freeze the whole backbone, then re-enable gradients for the last
        # trunk stage only.
        self.backbone.requires_grad_(False)
        trunk = getattr(self.backbone, 'trunk_output', None)
        if trunk is not None and hasattr(trunk, 'block4'):
            trunk.block4.requires_grad_(True)

        # Replace average pooling with max pooling.
        self.backbone.avgpool = nn.AdaptiveMaxPool2d(output_size=(1, 1))

        # New head; its layer order fixes the ``backbone.fc.<i>.*`` state-dict
        # keys, so it must stay identical for saved checkpoints to load.
        num_ftrs = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Linear(num_ftrs, 2048),
            nn.SiLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(2048, 1024),
            nn.SiLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(1024, 512),
            nn.SiLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Run the backbone (including the replaced head) on ``x``."""
        return self.backbone(x)


def verify_complete_backbone(model):
    """Check that ``model``'s state dict contains a full RegNet backbone layout.

    Counts ``state_dict`` keys (parameters *and* buffers such as BatchNorm
    running statistics) under the stem, the trunk blocks, the blocks'
    projection shortcuts and the classifier head, prints a report, and compares
    the counts with heuristic lower bounds.

    This is a structural check only: a randomly initialised model has exactly
    the same keys, so passing it does not prove that weights are pretrained.

    Args:
        model: A module whose state dict uses ``backbone.*`` key names, such as
            ``AIDetectorModel``.

    Returns:
        True if every threshold is exceeded, False otherwise.
    """
    all_keys = list(model.state_dict().keys())
    stem_keys = [k for k in all_keys if 'backbone.stem' in k]
    block_keys = [k for k in all_keys if 'backbone.trunk_output.block' in k]
    proj_keys = [k for k in block_keys if 'proj' in k]
    fc_keys = [k for k in all_keys if CLASSIFIER_PREFIX in k]

    print("šŸ” Backbone verification:")
    print(f" • Total keys: {len(all_keys)}")
    print(f" • Stem keys: {len(stem_keys)}")
    print(f" • Block keys: {len(block_keys)}")
    print(f" • Projection keys: {len(proj_keys)}")
    print(f" • FC keys: {len(fc_keys)}")

    # Heuristic lower bounds on *key* counts (not parameter counts); the FC
    # head is reported above but not checked.
    is_complete = (
        len(block_keys) > 400
        and len(stem_keys) > 5
        and len(proj_keys) > 20
    )

    if is_complete:
        print("āœ… Complete backbone verified!")
        print(f"āœ… Found {len(block_keys)} blocks, {len(stem_keys)} stem, {len(proj_keys)} proj layers")
    else:
        print("āŒ Incomplete backbone detected!")
        print("Expected: >400 blocks, >5 stem, >20 proj")
        print(f"Found: {len(block_keys)} blocks, {len(stem_keys)} stem, {len(proj_keys)} proj")
    return is_complete


def _split_checkpoint(checkpoint):
    """Return ``(state_dict, metadata)`` from a loaded checkpoint object.

    Accepts a dict holding ``'model_state_dict'`` (plus optional metadata) or a
    bare state dict. If every key starts with ``'module.'`` (the prefix
    ``nn.DataParallel`` adds), that prefix is stripped.

    Raises:
        RuntimeError: If the state dict is not a mapping.
    """
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state = checkpoint['model_state_dict']
        info = {
            'num_classes': checkpoint.get('num_classes', DEFAULT_NUM_CLASSES),
            'dropout_prob': checkpoint.get('dropout_prob', DEFAULT_DROPOUT_PROB),
            'epoch': checkpoint.get('epoch', 0),
            'val_loss': checkpoint.get('val_loss', 0.0),
        }
    else:
        state = checkpoint
        info = {'num_classes': DEFAULT_NUM_CLASSES, 'dropout_prob': DEFAULT_DROPOUT_PROB}

    if not hasattr(state, 'items'):
        raise RuntimeError(f"āŒ Unsupported checkpoint format: {type(state).__name__}")

    prefix = 'module.'
    if state and all(key.startswith(prefix) for key in state):
        state = {key[len(prefix):]: value for key, value in state.items()}
    return state, info


def _merge_state_dicts(fresh_state, trained_state, classifier_only=False):
    """Overlay trained tensors onto ``fresh_state`` where key and shape match.

    Args:
        fresh_state: State dict of the freshly built (pretrained) model.
        trained_state: State dict from the trained checkpoint.
        classifier_only: If True, only classifier-head keys are considered.

    Returns:
        ``(merged, taken, skipped)``: the merged state dict, the keys copied
        from ``trained_state``, and the trained keys ignored because they are
        unknown to the model, not tensors, or have a different shape.
    """
    merged = fresh_state.copy()
    taken, skipped = [], []
    for key, tensor in trained_state.items():
        if classifier_only and CLASSIFIER_PREFIX not in key:
            continue
        target = fresh_state.get(key)
        if (target is None or not isinstance(tensor, torch.Tensor)
                or target.shape != tensor.shape):
            skipped.append(key)
            continue
        merged[key] = tensor
        taken.append(key)
    return merged, taken, skipped


def create_truly_complete_model(trained_checkpoint_path, output_path, *, classifier_only=False):
    """Write a self-contained checkpoint: trained weights plus pretrained backbone.

    The trained checkpoint is loaded first, so the model is built with its
    ``num_classes``/``dropout_prob``. The model is initialised from ImageNet
    weights, and every trained tensor whose key and shape match is overlaid;
    pretrained values fill whatever the checkpoint lacks. The result is loaded
    strictly, verified and saved.

    Trained backbone tensors are kept by default. ``AIDetectorModel`` unfreezes
    ``trunk_output.block4`` for training, and BatchNorm running statistics
    update in train mode, so the trained backbone may differ from the
    ImageNet one.

    Args:
        trained_checkpoint_path: Path to a checkpoint dict with
            ``'model_state_dict'`` (and optional ``num_classes``,
            ``dropout_prob``, ``epoch``, ``val_loss``) or to a bare state dict.
        output_path: Destination file; missing parent directories are created.
        classifier_only: If True, take only classifier-head (``backbone.fc``)
            tensors from the checkpoint and keep every backbone tensor from
            the pretrained weights (the script's previous behaviour).

    Returns:
        ``output_path``.

    Raises:
        RuntimeError: If a backbone verification fails, the checkpoint format
            is unsupported, or the checkpoint does not provide every
            classifier-head tensor with a matching shape.
    """
    print("šŸ”„ Creating TRULY complete offline model...")
    print("=" * 60)

    # Step 1 runs first so the architecture matches the trained head.
    print(f"šŸ“‚ Step 1: Loading your trained weights from {trained_checkpoint_path}...")
    # NOTE(review): torch.load may unpickle arbitrary objects (depending on the
    # installed PyTorch's weights_only default); only load trusted checkpoints.
    checkpoint = torch.load(trained_checkpoint_path, map_location='cpu')
    trained_state, extra_info = _split_checkpoint(checkpoint)

    print("\nšŸ“„ Step 2: Loading fresh RegNet with ImageNet weights...")
    model = AIDetectorModel(
        num_classes=extra_info['num_classes'],
        dropout_prob=extra_info['dropout_prob'],
        load_pretrained=True,
    )

    print("\nšŸ” Step 3: Verifying fresh model completeness...")
    if not verify_complete_backbone(model):
        raise RuntimeError("āŒ Fresh model is not complete!")

    print("\nšŸ”€ Step 4: Merging pretrained backbone + trained weights...")
    fresh_state = model.state_dict()
    merged_state, taken, skipped = _merge_state_dicts(
        fresh_state, trained_state, classifier_only=classifier_only
    )

    # A head that is not fully replaced would be saved with random weights.
    taken_set = set(taken)
    missing_head = [k for k in fresh_state if CLASSIFIER_PREFIX in k and k not in taken_set]
    if missing_head:
        raise RuntimeError(
            "āŒ Trained checkpoint lacks matching classifier weights for: "
            + ", ".join(missing_head)
        )

    n_backbone_trained = sum(CLASSIFIER_PREFIX not in k for k in taken)
    print(f" • Classifier tensors from checkpoint: {len(taken) - n_backbone_trained}")
    print(f" • Backbone tensors from checkpoint: {n_backbone_trained}")
    print(f" • Tensors kept from ImageNet weights: {len(fresh_state) - len(taken)}")
    if skipped:
        preview = ", ".join(skipped[:5]) + (", ..." if len(skipped) > 5 else "")
        print(f" • Ignored {len(skipped)} checkpoint tensors (unknown key or shape mismatch): {preview}")

    print("\nāœ… Step 5: Loading merged weights into the model...")
    # strict=True (the default): every model key must be present in merged_state.
    model.load_state_dict(merged_state)

    print("\nšŸ” Final verification:")
    if not verify_complete_backbone(model):
        raise RuntimeError("āŒ Final model verification failed!")

    print(f"\nšŸ’¾ Step 6: Saving complete model to {output_path}...")
    complete_checkpoint = {
        'model_state_dict': model.state_dict(),
        'complete_model': True,
        'backbone_included': True,
        'offline_ready': True,
        **extra_info,
    }
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(complete_checkpoint, output_path)

    file_size = os.path.getsize(output_path) / (1024 * 1024)
    print("āœ… SUCCESS! Complete model saved:")
    print(f" • File: {output_path}")
    print(f" • Size: {file_size:.1f}MB")
    print(" • Ready for offline use!")
    return output_path


def main(argv=None):
    """Command-line entry point; returns a process exit code (0 on success)."""
    parser = argparse.ArgumentParser(
        description="Merge a trained AIDetector checkpoint with the pretrained "
                    "RegNet backbone into a self-contained checkpoint."
    )
    parser.add_argument('input_model', nargs='?', default="2_block_best_ai_detector.pth",
                        help="trained checkpoint (default: %(default)s)")
    parser.add_argument('output_model', nargs='?', default="truly_complete_ai_detector_new.pth",
                        help="output checkpoint (default: %(default)s)")
    parser.add_argument('--classifier-only', action='store_true',
                        help="take only classifier-head weights from the checkpoint")
    args = parser.parse_args(argv)

    if not os.path.exists(args.input_model):
        print(f"āŒ Trained model not found: {args.input_model}")
        # Help the user pick the right file from the current directory.
        pth_files = sorted(f for f in os.listdir('.') if f.endswith('.pth'))
        if pth_files:
            print("Available .pth files:")
            for name in pth_files:
                size = os.path.getsize(name) / (1024 * 1024)
                print(f" • {name} ({size:.1f}MB)")
        else:
            print("No .pth files found in current directory")
        return 1

    try:
        result_path = create_truly_complete_model(
            args.input_model, args.output_model, classifier_only=args.classifier_only
        )
    except Exception as e:  # top-level boundary: report, show traceback, fail
        print(f"āŒ Failed to create complete model: {e}")
        traceback.print_exc()
        return 1

    print("\n" + "=" * 60)
    print("šŸŽ‰ SUCCESS! Use this file in your offline app:")
    print(f"šŸ“ {result_path}")
    print("=" * 60)
    return 0


if __name__ == "__main__":
    sys.exit(main())