File size: 441 Bytes
91e3dad
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
import torch

def pre_trained_model_to_finetune(checkpoint, args):
    """Strip the classification-head entries from a pre-trained checkpoint.

    Only the ``class_embed`` parameters are removed, since the dataset used
    for fine-tuning has a different number of classes; every other entry is
    kept as-is.

    Args:
        checkpoint: Mapping with a ``'model'`` entry holding the model's
            state dict (parameter name -> value).
        args: Object exposing ``dec_layers`` (int) and ``two_stage``
            (truthy/falsy). When ``two_stage`` is truthy one extra
            ``class_embed`` index (``dec_layers + 1`` in total) is removed.

    Returns:
        The state dict stored under ``checkpoint['model']`` with the
        ``class_embed.{i}.weight`` / ``class_embed.{i}.bias`` entries removed.
        NOTE: this is the same dict object, modified in place, so the
        caller's ``checkpoint['model']`` is affected as well.

    Raises:
        KeyError: If ``checkpoint`` has no ``'model'`` entry.
    """
    state_dict = checkpoint['model']
    num_layers = args.dec_layers + 1 if args.two_stage else args.dec_layers
    for layer_idx in range(num_layers):
        for suffix in ("weight", "bias"):
            # pop() with a default instead of `del`: a checkpoint that lacks
            # some of these keys (e.g. already stripped) must not raise.
            state_dict.pop("class_embed.{}.{}".format(layer_idx, suffix), None)

    return state_dict