# convert_olmo_to_hf.py — from the DataDecide-dolma1_7-4M repository on the Hugging Face Hub
# (Hub page chrome from the original paste: uploaded by JollyTensor, commit 79e1917 "Upload 7 files")
import argparse
import logging
import os
import shutil
import torch
from omegaconf import OmegaConf as om
from hf_olmo.configuration_olmo import OLMoConfig
from hf_olmo.modeling_olmo import OLMoForCausalLM
from hf_olmo.tokenization_olmo_fast import OLMoTokenizerFast
from olmo import ModelConfig, Tokenizer
logger = logging.getLogger(__name__)
def write_config(checkpoint_dir: str):
    """Derive an HF-compatible config.json from the OLMo config.yaml in *checkpoint_dir*.

    Reads the ``model`` section of ``config.yaml``, enables the KV cache for
    generation, and writes ``config.json`` next to it via ``save_pretrained``.
    """
    logger.info(f"Loading checkpoint from {checkpoint_dir}")

    yaml_path = os.path.join(checkpoint_dir, "config.yaml")
    olmo_model_config = ModelConfig.load(yaml_path, key="model")

    # Start from the OLMo model settings, then turn on caching for HF generation.
    hf_kwargs = olmo_model_config.asdict()
    hf_kwargs["use_cache"] = True
    hf_config = OLMoConfig(**hf_kwargs)

    logger.info(f"Saving HF-compatible config to {os.path.join(checkpoint_dir, 'config.json')}")
    hf_config.save_pretrained(checkpoint_dir)
def write_model(checkpoint_dir: str, ignore_olmo_compatibility: bool = False):
    """Convert ``model.pt`` into an HF-loadable ``pytorch_model.bin``.

    For device_map = "auto", etc. the models are loaded in a way that start_prefix
    is not computed correctly, so we explicitly store the state dict with the
    expected ``OLMoForCausalLM.base_model_prefix`` prefix on every key.

    Args:
        checkpoint_dir: Directory containing ``model.pt``; ``pytorch_model.bin``
            is written alongside it.
        ignore_olmo_compatibility: If True, delete the original ``model.pt``
            (the checkpoint will no longer be usable by the olmo codebase).
    """
    old_model_path = os.path.join(checkpoint_dir, "model.pt")
    new_model_path = os.path.join(checkpoint_dir, "pytorch_model.bin")

    # map_location="cpu" so checkpoints saved from GPU load on CPU-only hosts
    # (torch.load would otherwise try to restore tensors to their original device).
    state_dict = torch.load(old_model_path, map_location="cpu")
    new_state_dict = {
        f"{OLMoForCausalLM.base_model_prefix}.{key}": val for key, val in state_dict.items()
    }
    torch.save(new_state_dict, new_model_path)

    if ignore_olmo_compatibility:
        os.remove(old_model_path)
def write_tokenizer(checkpoint_dir: str):
    """Build an HF fast tokenizer from the OLMo tokenizer and save it to *checkpoint_dir*.

    Copies over truncation settings and the pad/eos special-token ids so the
    saved tokenizer matches the one the checkpoint was trained with.
    """
    raw = Tokenizer.from_checkpoint(checkpoint_dir)

    # Recover the literal EOS string from its id (special tokens are kept).
    eos_text = raw.decode([raw.eos_token_id], skip_special_tokens=False)
    hf_tokenizer = OLMoTokenizerFast(
        tokenizer_object=raw.base_tokenizer,
        truncation=raw.truncate_direction,
        max_length=raw.truncate_to,
        eos_token=eos_text,
    )
    hf_tokenizer.model_input_names = ["input_ids", "attention_mask"]
    hf_tokenizer.pad_token_id = raw.pad_token_id
    hf_tokenizer.eos_token_id = raw.eos_token_id

    hf_tokenizer.save_pretrained(checkpoint_dir)
def convert_checkpoint(checkpoint_dir: str, ignore_olmo_compatibility: bool = False):
    """Run the full OLMo → HF conversion for one checkpoint directory.

    Writes config.json, pytorch_model.bin, and the tokenizer files in place.
    When *ignore_olmo_compatibility* is True, olmo-only files are removed.
    """
    write_config(checkpoint_dir)
    write_model(checkpoint_dir, ignore_olmo_compatibility=ignore_olmo_compatibility)
    write_tokenizer(checkpoint_dir)

    # config.yaml is read by write_tokenizer above, so it can only be removed last.
    if ignore_olmo_compatibility:
        os.remove(os.path.join(checkpoint_dir, "config.yaml"))
def download_remote_checkpoint_and_convert_to_hf(checkpoint_dir: str, local_dir: str):
    """Fetch a remote OLMo checkpoint into *local_dir* and convert it to HF format.

    Downloads ``model.pt`` and ``config.yaml`` (skipping files already present),
    then runs :func:`convert_checkpoint` on the local copy.

    Args:
        checkpoint_dir: Remote checkpoint location (URL or path understood by
            ``cached_path``); its basename names the local model directory.
        local_dir: Parent directory for the downloaded checkpoint.

    Returns:
        The local path of the converted model directory.
    """
    from cached_path import cached_path

    model_name = os.path.basename(checkpoint_dir)
    local_model_path = os.path.join(local_dir, model_name)
    os.makedirs(local_model_path, exist_ok=True)

    model_files = ["model.pt", "config.yaml"]  # , "optim.pt", "other.pt"]
    for filename in model_files:
        final_location = os.path.join(local_model_path, filename)
        if not os.path.exists(final_location):
            remote_file = os.path.join(checkpoint_dir, filename)
            # Fixed: the original message was a placeholder-less f-string
            # ("Downloading file (unknown)") that logged no useful information.
            logger.debug("Downloading file from %s", remote_file)
            cached_file = cached_path(remote_file)
            shutil.copy(cached_file, final_location)
            logger.debug("File saved at %s", final_location)
        else:
            logger.info("File already present at %s", final_location)

    convert_checkpoint(local_model_path)
    return local_model_path
def fix_bad_tokenizer(checkpoint_dir: str):
    """Patch config.yaml in place with the correct tokenizer identifier and EOS id.

    Older checkpoints recorded a wrong tokenizer/eos pair; this rewrites both
    fields before conversion so the exported HF tokenizer is consistent.
    """
    config_path = os.path.join(checkpoint_dir, "config.yaml")

    raw_conf = om.load(config_path)
    raw_conf["tokenizer"]["identifier"] = "allenai/gpt-neox-olmo-dolma-v1_5"
    raw_conf["model"]["eos_token_id"] = 50279
    om.save(raw_conf, config_path)
def main():
    """CLI entry point: repair the tokenizer config, then convert the checkpoint."""
    arg_parser = argparse.ArgumentParser(
        description="Adds a config.json to the checkpoint directory, and creates pytorch_model.bin, "
        "making it easier to load weights as HF models."
    )
    arg_parser.add_argument(
        "--checkpoint-dir",
        help="Location of OLMo checkpoint.",
    )
    arg_parser.add_argument(
        "--ignore-olmo-compatibility",
        action="store_true",
        help="Ignore compatibility with the olmo codebase. "
        "This will remove files that are needed specifically for olmo codebase, eg. config.yaml, etc.",
    )
    parsed = arg_parser.parse_args()

    fix_bad_tokenizer(parsed.checkpoint_dir)
    convert_checkpoint(parsed.checkpoint_dir, parsed.ignore_olmo_compatibility)
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()