import math

import torch


# (checkpoint suffix, model suffix) pairs copied for every transformer block:
# ``blocks.<i>.<checkpoint suffix>`` -> ``layers.<i>.<model suffix>``.
_BLOCK_KEY_MAP = (
    ('norm1.weight', 'ln1.weight'),
    ('norm1.bias', 'ln1.bias'),
    ('norm2.weight', 'ln2.weight'),
    ('norm2.bias', 'ln2.bias'),
    ('mlp.fc1.weight', 'ffn.layers.0.0.weight'),
    ('mlp.fc1.bias', 'ffn.layers.0.0.bias'),
    ('mlp.fc2.weight', 'ffn.layers.1.weight'),
    ('mlp.fc2.bias', 'ffn.layers.1.bias'),
    ('attn.qkv.weight', 'attn.attn.in_proj_weight'),
    ('attn.qkv.bias', 'attn.attn.in_proj_bias'),
    ('attn.proj.weight', 'attn.attn.out_proj.weight'),
    ('attn.proj.bias', 'attn.attn.out_proj.bias'),
)


def _checkpoint_block_indices(checkpoint):
    """Return the sorted indices ``i`` that occur in ``blocks.<i>.*`` keys."""
    indices = set()
    for key in checkpoint:
        parts = key.split('.')
        if len(parts) > 2 and parts[0] == 'blocks' and parts[1].isdigit():
            indices.add(int(parts[1]))
    return sorted(indices)


def _resize_pos_embed(pos_embed, target_shape):
    """Bicubically resize a ``(B, tokens, C)`` positional embedding.

    The token axis is treated as ``num_extra`` leading tokens (one is tried
    first, then none) followed by a square grid of patch positions. Returns
    the resized tensor, or ``None`` when the layout cannot be interpreted
    that way for both the source and the target shape.
    """
    if pos_embed.dim() != 3 or len(target_shape) != 3:
        return None
    batch, src_tokens, channels = pos_embed.shape
    if batch != target_shape[0] or channels != target_shape[2]:
        return None
    dst_tokens = target_shape[1]

    for num_extra in (1, 0):
        src_patches = src_tokens - num_extra
        dst_patches = dst_tokens - num_extra
        if src_patches <= 0 or dst_patches <= 0:
            continue
        src_side = math.isqrt(src_patches)
        dst_side = math.isqrt(dst_patches)
        if src_side * src_side != src_patches or dst_side * dst_side != dst_patches:
            continue

        extra_tokens = pos_embed[:, :num_extra]
        # (B, N, C) -> (B, C, H, W) so the grid can be interpolated spatially.
        grid = pos_embed[:, num_extra:].reshape(batch, src_side, src_side, channels)
        grid = grid.permute(0, 3, 1, 2)
        grid = torch.nn.functional.interpolate(
            grid, size=(dst_side, dst_side), mode='bicubic', align_corners=False)
        grid = grid.permute(0, 2, 3, 1).reshape(batch, dst_patches, channels)
        return torch.cat([extra_tokens, grid], dim=1)
    return None


def _drop_incompatible_keys(new_state_dict, model_state_dict):
    """Make every entry of ``new_state_dict`` loadable into the model, in place.

    Keys the model does not have are removed. Shape-mismatched tensors are
    reshaped when the element counts agree and removed otherwise. Fused
    attention ``in_proj`` tensors are never reshaped: a mismatch there is
    always removed, because ``load_state_dict`` raises on size mismatches
    even with ``strict=False``.
    """
    for key in list(new_state_dict):
        if key not in model_state_dict:
            del new_state_dict[key]
            continue

        value = new_state_dict[key]
        target_shape = model_state_dict[key].shape
        if value.shape == target_shape:
            continue

        print(f"Shape mismatch for {key}: checkpoint {value.shape} vs model {target_shape}")
        same_numel = value.numel() == model_state_dict[key].numel()
        if 'attn.attn.in_proj' not in key and same_numel:
            new_state_dict[key] = value.reshape(target_shape)
        else:
            del new_state_dict[key]


def load_weights_by_order(model, checkpoint_path, map_location='cpu'):
    """Load a ``blocks.<i>.*``-style ViT checkpoint into an mmseg ViT model.

    Checkpoint keys are renamed to the model's naming
    (``patch_embed.proj`` -> ``patch_embed.projection``,
    ``blocks.<i>.*`` -> ``layers.<i>.*``, ``fc_norm``/``norm`` -> ``ln1``),
    entries the model cannot accept are dropped, and the result is loaded
    with ``strict=False``. The number of transformer blocks is taken from
    the checkpoint keys instead of being fixed.

    Args:
        model: The mmseg VisionTransformer model to load weights into.
        checkpoint_path: Path to the checkpoint file.
        map_location: Device mapping for loading the checkpoint.

    Returns:
        The same ``model`` object, with the compatible weights loaded.
    """
    print(f"Loading checkpoint from {checkpoint_path}")

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources (or pass weights_only=True where supported).
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    # Unwrap the common {"model": ...} / {"state_dict": ...} containers.
    if 'model' in checkpoint:
        checkpoint = checkpoint['model']
    elif 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']

    model_state_dict = model.state_dict()
    new_state_dict = {}

    # Patch embedding.
    if 'patch_embed.proj.weight' in checkpoint:
        new_state_dict['patch_embed.projection.weight'] = checkpoint['patch_embed.proj.weight']
    if 'patch_embed.proj.bias' in checkpoint:
        if 'patch_embed.projection.bias' in model_state_dict:
            new_state_dict['patch_embed.projection.bias'] = checkpoint['patch_embed.proj.bias']
        else:
            print("Skipping patch_embed.projection.bias as it's not in the model")

    # Transformer blocks: depth is inferred from the checkpoint itself.
    block_indices = _checkpoint_block_indices(checkpoint)
    if not block_indices:
        print("No blocks.<i>.* keys found in checkpoint")
    for i in block_indices:
        for src_suffix, dst_suffix in _BLOCK_KEY_MAP:
            src_key = f'blocks.{i}.{src_suffix}'
            if src_key in checkpoint:
                new_state_dict[f'layers.{i}.{dst_suffix}'] = checkpoint[src_key]
            else:
                # Left untouched in the model; load_state_dict reports it
                # among the missing keys.
                print(f"Checkpoint has no {src_key}; layers.{i}.{dst_suffix} is left unchanged")

    # Final norm: prefer the checkpoint's fc_norm when present.
    if 'fc_norm.weight' in checkpoint:
        if 'ln1.weight' in model_state_dict:
            new_state_dict['ln1.weight'] = checkpoint['fc_norm.weight']
            new_state_dict['ln1.bias'] = checkpoint['fc_norm.bias']
        elif 'norm.weight' in model_state_dict:
            new_state_dict['norm.weight'] = checkpoint['fc_norm.weight']
            new_state_dict['norm.bias'] = checkpoint['fc_norm.bias']

    # Fallbacks for ln1 when fc_norm did not provide it.
    # NOTE(review): the last block's norm2 takes priority over a
    # checkpoint-level `norm.*`; this keeps the original ordering - confirm
    # that it is intended.
    if 'ln1.weight' in model_state_dict and 'ln1.weight' not in new_state_dict:
        last_norm2 = f'blocks.{block_indices[-1]}.norm2' if block_indices else None
        if last_norm2 is not None and f'{last_norm2}.weight' in checkpoint:
            print(f"Using {last_norm2} weights for ln1")
            new_state_dict['ln1.weight'] = checkpoint[f'{last_norm2}.weight']
            new_state_dict['ln1.bias'] = checkpoint[f'{last_norm2}.bias']
        elif 'norm.weight' in checkpoint:
            print("Using norm weights for ln1")
            new_state_dict['ln1.weight'] = checkpoint['norm.weight']
            new_state_dict['ln1.bias'] = checkpoint['norm.bias']

    # Positional embedding: resize when the token counts differ.
    if hasattr(model, 'pos_embed') and 'pos_embed' in checkpoint:
        checkpoint_pos_embed = checkpoint['pos_embed']
        if 'pos_embed' in model_state_dict:
            model_pos_embed_shape = model_state_dict['pos_embed'].shape
            if checkpoint_pos_embed.shape != model_pos_embed_shape:
                print(f"Resizing positional embedding from {checkpoint_pos_embed.shape} to {model_pos_embed_shape}")
                resized = _resize_pos_embed(checkpoint_pos_embed, model_pos_embed_shape)
                if resized is None:
                    # Falls through to the generic reshape-or-drop handling.
                    print("Could not resize positional embedding (non-square or incompatible layout)")
                else:
                    checkpoint_pos_embed = resized
        new_state_dict['pos_embed'] = checkpoint_pos_embed

    if hasattr(model, 'cls_token') and 'cls_token' in checkpoint:
        if 'cls_token' in model_state_dict:
            new_state_dict['cls_token'] = checkpoint['cls_token']

    _drop_incompatible_keys(new_state_dict, model_state_dict)

    missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)

    print(f"Loaded {len(new_state_dict)} keys into model")
    print(f"Missing keys: {len(missing_keys)}")
    print(f"Unexpected keys: {len(unexpected_keys)}")

    if missing_keys:
        print(f"Sample missing keys: {missing_keys[:5]}")
    if unexpected_keys:
        print(f"Sample unexpected keys: {unexpected_keys[:5]}")

    return model