| |
| |
|
|
| import argparse |
| from pathlib import Path |
|
|
| import torch |
| import torch.distributed.checkpoint as DCP |
| from transformers import AutoModelForCausalLM |
|
|
| import fla |
| from torchtitan.tools.logging import init_logger, logger |
|
|
|
|
@torch.inference_mode()
def convert_hf_weights(model: str, checkpoint: Path):
    """Load a Hugging Face causal-LM and save its weights as a DCP checkpoint.

    Args:
        model: Model identifier or local path, forwarded unchanged to
            ``AutoModelForCausalLM.from_pretrained``.
        checkpoint: Output directory for the DCP checkpoint. A plain ``str``
            is also accepted and converted to ``Path``. The directory (and
            any missing parents) is created if it does not exist.

    The model's state dict is written under the top-level key ``"model"``
    using a ``FileSystemWriter`` with 8 writer threads.
    """
    logger.info(f"Loading model from {model}")
    # Use a separate name so ``model`` keeps meaning "the identifier string".
    hf_model = AutoModelForCausalLM.from_pretrained(model)
    state_dict = hf_model.state_dict()

    # Normalize to Path so callers may pass a str; ``.mkdir`` below needs Path.
    checkpoint = Path(checkpoint)
    logger.info(f"Writing to DCP at '{checkpoint}'")
    checkpoint.mkdir(parents=True, exist_ok=True)
    storage_writer = DCP.filesystem.FileSystemWriter(checkpoint, thread_count=8)
    DCP.save({"model": state_dict}, storage_writer=storage_writer)
|
|
|
|
if __name__ == "__main__":
    # Logging must be configured before anything else emits messages.
    init_logger()

    # CLI: --model is handed to from_pretrained; --checkpoint is the DCP output dir.
    cli = argparse.ArgumentParser(
        description="Convert huggingface-style model weights to DCP format."
    )
    cli.add_argument("--model", type=str, required=True)
    cli.add_argument("--checkpoint", type=Path, required=True)
    cli_args = cli.parse_args()

    convert_hf_weights(cli_args.model, cli_args.checkpoint)
|
|