File size: 8,360 Bytes
import torch
def load_weights_by_order(model, checkpoint_path, map_location='cpu'):
    """Load a ``blocks.*``-style ViT checkpoint into a ``layers.*``-style model.

    The checkpoint is expected to use key names such as
    ``patch_embed.proj.*``, ``blocks.{i}.norm1.*``, ``blocks.{i}.attn.qkv.*``
    and ``blocks.{i}.mlp.fc1.*``. They are renamed to the target layout
    (``patch_embed.projection.*``, ``layers.{i}.ln1.*``,
    ``layers.{i}.attn.attn.in_proj_*``, ``layers.{i}.ffn.layers.*``) and
    loaded non-strictly.

    The number of transformer blocks is inferred from the checkpoint keys, so
    checkpoints of any depth are supported. Tensors that are absent from the
    checkpoint, absent from the model, or whose shape cannot be reconciled
    with the model are skipped and show up in the "missing keys" report.

    Args:
        model: The mmseg VisionTransformer model to load weights into.
        checkpoint_path: Path (or file-like object) of the checkpoint.
        map_location: Device mapping passed to ``torch.load``.

    Returns:
        The same ``model`` object, with the matching weights loaded.
    """
    print(f"Loading checkpoint from {checkpoint_path}")
    # NOTE(security): torch.load unpickles the file. Only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    # Unwrap the state dict if it is nested under a well-known key.
    if 'model' in checkpoint:
        checkpoint = checkpoint['model']
    elif 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']

    # Target state dict, used to check key existence and expected shapes.
    model_state_dict = model.state_dict()
    new_state_dict = {}

    # --- Patch embedding -------------------------------------------------
    if 'patch_embed.proj.weight' in checkpoint:
        new_state_dict['patch_embed.projection.weight'] = checkpoint['patch_embed.proj.weight']
    if 'patch_embed.proj.bias' in checkpoint:
        if 'patch_embed.projection.bias' in model_state_dict:
            new_state_dict['patch_embed.projection.bias'] = checkpoint['patch_embed.proj.bias']
        else:
            print("Skipping patch_embed.projection.bias as it's not in the model")

    # --- Transformer blocks ----------------------------------------------
    # Infer the depth from the checkpoint instead of assuming 12 blocks.
    block_indices = set()
    for key in checkpoint:
        parts = key.split('.')
        if len(parts) > 2 and parts[0] == 'blocks' and parts[1].isdigit():
            block_indices.add(int(parts[1]))
    num_blocks = max(block_indices) + 1 if block_indices else 0

    # (checkpoint suffix, model suffix) for every tensor inside one block.
    # The fused qkv projection is copied as-is into in_proj_*; both layouts
    # stack the query, key and value rows in that order.
    block_key_map = (
        ('norm1.weight', 'ln1.weight'),
        ('norm1.bias', 'ln1.bias'),
        ('norm2.weight', 'ln2.weight'),
        ('norm2.bias', 'ln2.bias'),
        ('mlp.fc1.weight', 'ffn.layers.0.0.weight'),
        ('mlp.fc1.bias', 'ffn.layers.0.0.bias'),
        ('mlp.fc2.weight', 'ffn.layers.1.weight'),
        ('mlp.fc2.bias', 'ffn.layers.1.bias'),
        ('attn.qkv.weight', 'attn.attn.in_proj_weight'),
        ('attn.qkv.bias', 'attn.attn.in_proj_bias'),
        ('attn.proj.weight', 'attn.attn.out_proj.weight'),
        ('attn.proj.bias', 'attn.attn.out_proj.bias'),
    )
    for i in range(num_blocks):
        for source_suffix, target_suffix in block_key_map:
            source_key = f'blocks.{i}.{source_suffix}'
            # Optional tensors (e.g. a missing qkv bias) are skipped rather
            # than aborting the whole load with a KeyError.
            if source_key in checkpoint:
                new_state_dict[f'layers.{i}.{target_suffix}'] = checkpoint[source_key]

    # --- Final layer norm --------------------------------------------------
    # Preferred source is 'fc_norm'; the target may be called 'ln1' or 'norm'.
    if 'fc_norm.weight' in checkpoint:
        for target in ('ln1', 'norm'):
            if f'{target}.weight' in model_state_dict:
                new_state_dict[f'{target}.weight'] = checkpoint['fc_norm.weight']
                new_state_dict[f'{target}.bias'] = checkpoint['fc_norm.bias']
                break

    # If the model's final 'ln1' is still unmapped, prefer the checkpoint's
    # own final 'norm'; only fall back to the last block's norm2 when no
    # dedicated final norm is available.
    if 'ln1.weight' in model_state_dict and 'ln1.weight' not in new_state_dict:
        fallback_sources = ['norm']
        if num_blocks:
            fallback_sources.append(f'blocks.{num_blocks - 1}.norm2')
        for source in fallback_sources:
            if f'{source}.weight' in checkpoint and f'{source}.bias' in checkpoint:
                print(f"Using {source} weights for ln1")
                new_state_dict['ln1.weight'] = checkpoint[f'{source}.weight']
                new_state_dict['ln1.bias'] = checkpoint[f'{source}.bias']
                break

    # --- Positional embedding and class token ------------------------------
    if (hasattr(model, 'pos_embed') and 'pos_embed' in checkpoint
            and 'pos_embed' in model_state_dict):
        checkpoint_pos_embed = checkpoint['pos_embed']
        model_pos_embed_shape = model_state_dict['pos_embed'].shape
        if checkpoint_pos_embed.shape != model_pos_embed_shape:
            # No interpolation is done here; the generic shape check below
            # either reshapes (equal element count) or skips the tensor.
            print(f"Positional embedding shape differs: checkpoint "
                  f"{checkpoint_pos_embed.shape} vs model {model_pos_embed_shape} "
                  f"(no interpolation is performed)")
        new_state_dict['pos_embed'] = checkpoint_pos_embed

    if (hasattr(model, 'cls_token') and 'cls_token' in checkpoint
            and 'cls_token' in model_state_dict):
        new_state_dict['cls_token'] = checkpoint['cls_token']

    # --- Reconcile keys and shapes with the model --------------------------
    # Build a filtered dict rather than mutating new_state_dict in the loop.
    loadable_state_dict = {}
    for key, value in new_state_dict.items():
        if key not in model_state_dict:
            # The model has no such parameter; nothing to load.
            continue
        target = model_state_dict[key]
        if value.shape == target.shape:
            loadable_state_dict[key] = value
            continue

        print(f"Shape mismatch for {key}: checkpoint {value.shape} vs model {target.shape}")
        if 'attn.attn.in_proj' in key:
            # A fused QKV tensor of a different size cannot be reinterpreted
            # meaningfully, and load_state_dict raises on size mismatches even
            # with strict=False, so the key must be dropped.
            continue
        if value.numel() == target.numel():
            # Same element count: treat as a pure layout difference.
            loadable_state_dict[key] = value.reshape(target.shape)
        # Otherwise the tensor is incompatible and is skipped.

    # --- Load and report ---------------------------------------------------
    missing_keys, unexpected_keys = model.load_state_dict(loadable_state_dict, strict=False)
    print(f"Loaded {len(loadable_state_dict)} keys into model")
    print(f"Missing keys: {len(missing_keys)}")
    print(f"Unexpected keys: {len(unexpected_keys)}")
    # Show a few entries to help debugging without flooding the log.
    if missing_keys:
        print(f"Sample missing keys: {missing_keys[:5]}")
    if unexpected_keys:
        print(f"Sample unexpected keys: {unexpected_keys[:5]}")
    return model