| import argparse |
| import os |
|
|
| import torch |
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line arguments of the conversion script.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line (``sys.argv[1:]``). ``None`` keeps argparse's default
            behaviour, so existing ``parse_args()`` callers are unaffected.

    Returns:
        argparse.Namespace with two string attributes: ``source_model`` (path
        of the checkpoint to convert) and ``output_model`` (path to write the
        converted weights to).
    """
    # The text is a description, not a program name: passing it positionally
    # would set ``prog`` and print it as the command in the usage line.
    parser = argparse.ArgumentParser(
        description="Convert Swin Transformer to Detectron2")
    # Positional arguments are required, so a ``default`` would never be used.
    parser.add_argument("source_model", type=str, help="Source model")
    parser.add_argument("output_model", type=str, help="Output model")
    return parser.parse_args(argv)
|
|
|
|
def main():
    """Convert a Swin Transformer checkpoint into a Detectron2-loadable file.

    Reads the checkpoint named on the command line, prefixes every parameter
    name with ``backbone.bottom_up.`` and saves the resulting flat dict of
    weights to the output path.

    Raises:
        ValueError: if the source path does not have a ``.pth`` extension.
    """
    args = parse_args()

    # Compare case-insensitively so "model.PTH" is accepted as well.
    if os.path.splitext(args.source_model)[-1].lower() != ".pth":
        raise ValueError("You should save weights as pth file")

    # NOTE(review): torch.load unpickles the file; only convert checkpoints
    # from trusted sources.
    checkpoint = torch.load(
        args.source_model, map_location=torch.device("cpu"))

    # Checkpoints that wrap the weights under a "model" key are unwrapped;
    # anything else is treated as the weights dict itself instead of failing
    # with a bare KeyError.
    if isinstance(checkpoint, dict) and "model" in checkpoint:
        source_weights = checkpoint["model"]
    else:
        source_weights = checkpoint

    # Prefix under which this script places the weights in the Detectron2
    # model (presumably the backbone's bottom-up module — confirm against the
    # target model definition).
    prefix = "backbone.bottom_up."
    converted_weights = {
        prefix + key: value for key, value in source_weights.items()
    }

    torch.save(converted_weights, args.output_model)
|
|
|
|
# Run the conversion only when this file is executed as a script, so that
# importing it (e.g. to reuse parse_args) has no side effects.
if __name__ == "__main__":
    main()
|
|