File size: 918 Bytes
3334467
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#!/usr/bin/env python3
import argparse
from pathlib import Path

import torch


def main(argv=None):
    """Copy the PixDLM alignment-module tensors out of a full checkpoint.

    Loads the checkpoint named by ``--model`` onto the CPU, keeps every entry
    whose key starts with one of the prefixes below, and writes that subset to
    ``--output`` with ``torch.save``.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour.

    Raises:
        SystemExit: if the loaded checkpoint is not a dict, or if no key
            matches any prefix. Nothing is written in either case.
    """
    parser = argparse.ArgumentParser(description="Extract PixDLM alignment weights.")
    parser.add_argument("--model", default="pretrained/pixdlm-7b/pytorch_model.bin")
    parser.add_argument("--output", default="pretrained/pixdlm-7b/alignment_weights.pth")
    args = parser.parse_args(argv)

    model_path = Path(args.model)
    output_path = Path(args.output)

    # NOTE(review): torch.load unpickles the file, so only point --model at
    # checkpoints that come from a trusted source.
    state = torch.load(model_path, map_location="cpu")
    if not isinstance(state, dict):
        raise SystemExit(
            f"Expected a state dict in {model_path}, got {type(state).__name__}"
        )

    # Key prefixes of the modules to extract. Matching is a plain string
    # prefix test, so a key such as "model.sam_to_embed_conv2.weight" would
    # also match the first entry.
    prefixes = (
        "model.sam_to_embed_conv",
        "model.image_feature_neck",
        "model.mm_projector_for_mask",
    )
    extracted = {k: v for k, v in state.items() if k.startswith(prefixes)}
    if not extracted:
        # Saving an empty dict would silently produce a useless artifact
        # (e.g. when the checkpoint's key naming differs from the prefixes).
        raise SystemExit(
            f"No keys starting with {', '.join(prefixes)} found in {model_path}"
        )

    # Create the output directory only once there is something to write, so a
    # failed run does not leave empty directories behind.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(extracted, output_path)
    print(f"Saved {len(extracted)} tensors to {output_path}")


if __name__ == "__main__":
    main()