#!/usr/bin/env python3
"""Extract the PixDLM alignment-module weights from a full checkpoint.

Loads a checkpoint onto the CPU, keeps only the entries whose keys start
with one of the alignment-module prefixes, and saves that subset to a new
file.
"""

import argparse
from pathlib import Path

import torch

# Key prefixes of the alignment modules that are kept by default.
DEFAULT_PREFIXES = (
    "model.sam_to_embed_conv",
    "model.image_feature_neck",
    "model.mm_projector_for_mask",
)


def main(argv=None):
    """Filter a checkpoint down to the alignment weights and save them.

    Args:
        argv: Optional argument list. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``. Passing a list lets the script be driven
            programmatically.

    Raises:
        SystemExit: If no checkpoint entry matches any requested prefix.
            This prevents an empty weight file from being written silently.
    """
    parser = argparse.ArgumentParser(description="Extract PixDLM alignment weights.")
    parser.add_argument("--model", default="pretrained/pixdlm-7b/pytorch_model.bin")
    parser.add_argument("--output", default="pretrained/pixdlm-7b/alignment_weights.pth")
    parser.add_argument(
        "--prefix",
        action="append",
        dest="prefixes",
        metavar="PREFIX",
        help="Key prefix to keep; may be repeated. Defaults to the PixDLM alignment modules.",
    )
    args = parser.parse_args(argv)

    model_path = Path(args.model)
    output_path = Path(args.output)
    # Use default=None plus a fallback, not a list default. With
    # action="append", argparse would append user values onto a mutable
    # default instead of replacing it.
    prefixes = tuple(args.prefixes) if args.prefixes else DEFAULT_PREFIXES

    # NOTE(review): torch.load unpickles the file, so only run this on
    # trusted checkpoints. PyTorch >= 2.6 defaults to weights_only=True;
    # older versions do not.
    state = torch.load(model_path, map_location="cpu")

    # str.startswith accepts a tuple, so one call tests every prefix.
    extracted = {k: v for k, v in state.items() if k.startswith(prefixes)}
    if not extracted:
        raise SystemExit(
            f"No tensors in {model_path} match prefixes {prefixes}; nothing saved."
        )

    # Create the destination directory only once there is something to write.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(extracted, output_path)
    print(f"Saved {len(extracted)} tensors to {output_path}")


if __name__ == "__main__":
    main()