PixDLM / extract.py
WhynotHug's picture
Upload folder using huggingface_hub
3334467 verified
Raw
History Blame Contribute Delete
918 Bytes
#!/usr/bin/env python3
import argparse
from pathlib import Path
import torch
# Parameter-name prefixes kept by default; each names one alignment
# sub-module whose tensors are copied out of the full checkpoint.
_DEFAULT_PREFIXES = (
    "model.sam_to_embed_conv",
    "model.image_feature_neck",
    "model.mm_projector_for_mask",
)


def main(argv=None):
    """Copy the PixDLM alignment tensors out of a full checkpoint.

    Loads the checkpoint at ``--model`` onto the CPU. It keeps every entry
    whose key starts with one of the selected prefixes and writes that subset
    to ``--output`` with ``torch.save``.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behavior.

    Raises:
        FileNotFoundError: from ``torch.load`` when ``--model`` does not exist.
        SystemExit: via ``parser.error`` when the loaded object is not a dict,
            or when no key matches any selected prefix.
    """
    parser = argparse.ArgumentParser(description="Extract PixDLM alignment weights.")
    parser.add_argument(
        "--model",
        default="pretrained/pixdlm-7b/pytorch_model.bin",
        help="checkpoint to read (loaded with torch.load onto the CPU)",
    )
    parser.add_argument(
        "--output",
        default="pretrained/pixdlm-7b/alignment_weights.pth",
        help="file to write the extracted tensors to",
    )
    parser.add_argument(
        "--prefix",
        action="append",
        dest="prefixes",
        metavar="PREFIX",
        help=(
            "key prefix to keep; may be repeated "
            "(default: the three PixDLM alignment modules)"
        ),
    )
    args = parser.parse_args(argv)

    model_path = Path(args.model)
    output_path = Path(args.output)
    # With action="append", a non-None default would be appended to rather
    # than replaced, so the fallback to the defaults is applied here instead.
    prefixes = tuple(args.prefixes) if args.prefixes else _DEFAULT_PREFIXES

    # NOTE(security): torch.load unpickles the file, which can execute code.
    # Only load checkpoints from a trusted source. On torch versions that
    # support it, weights_only=True restricts loading to tensors and plain
    # containers.
    state = torch.load(model_path, map_location="cpu")
    if not isinstance(state, dict):
        parser.error(
            f"expected a state dict in {model_path}, got {type(state).__name__}"
        )

    extracted = {k: v for k, v in state.items() if k.startswith(prefixes)}
    if not extracted:
        # An empty output file would look like a successful extraction.
        parser.error(
            f"no keys in {model_path} start with any of: {', '.join(prefixes)}"
        )

    # Create the output directory only once there is something to write, so
    # a failed run does not leave an empty directory behind.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(extracted, output_path)
    print(f"Saved {len(extracted)} tensors to {output_path}")
if __name__ == "__main__":
    # Run the extraction only when executed as a script, not when imported.
    main()