| import argparse | |
| import torch | |
| from modeling_bitnet import BitnetForCausalLM | |
| from tokenization_bitnet import BitnetTokenizer | |
# Inference/quantization only — no gradients are ever needed in this script.
torch.set_grad_enabled(False)

# CLI: where to load the HF model from and where to write the quantized copy.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--hf_path",
    type=str,
    default="1bitLLM/bitnet_b1_58-3B",
)
parser.add_argument(
    "--output_path",
    type=str,
    default="./bitnet_b1_58-3B_quantized",
)
def main(args):
    """Load a BitNet model from the hub, quantize it, and save a complete checkpoint.

    Args:
        args: Parsed CLI namespace with ``hf_path`` (source model id/path) and
            ``output_path`` (destination directory for the quantized checkpoint).
    """
    model = BitnetForCausalLM.from_pretrained(
        args.hf_path,
        device_map="auto",
        low_cpu_mem_usage=True,
        # NOTE(review): `use_flash_attention_2` is deprecated in newer
        # transformers in favor of `attn_implementation="flash_attention_2"`;
        # kept as-is for compatibility with the pinned library version.
        use_flash_attention_2=True,
        torch_dtype=torch.float16,
    )
    tokenizer = BitnetTokenizer.from_pretrained(args.hf_path, use_fast=False)
    # In-place quantization provided by the project-local BitnetForCausalLM.
    model.quantize()
    model.save_pretrained(args.output_path, max_shard_size="5GB")
    # Fix: also persist the tokenizer so the output directory is a complete,
    # directly loadable checkpoint (previously the tokenizer was loaded but
    # never used, leaving the saved model without tokenizer files).
    tokenizer.save_pretrained(args.output_path)
    print(f"Quantized model saved to {args.output_path}")
if __name__ == "__main__":
    # Parse CLI arguments and run the quantization pipeline.
    main(parser.parse_args())