| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import argparse |
| from pathlib import Path |
|
|
| from nemo import lightning as nl |
| from nemo.collections import llm |
|
|
| """ |
| Script to import a Hugging Face model checkpoint into NeMo 2.0 format. |
| |
| Example usage: |
| |
| python test_hf_import.py \ |
| --hf_model /path/to/hf/model \ |
| --model LlamaModel \ |
| --config Llama31Config8B \ |
| --output_path /path/to/nemo2/model |
| |
| The source model can be a local directory or a model from HF Hub. |
| It may have different parameters than specified in the config. For example, |
| it may be a small model with just two layers and a lower hidden dimension size. |
| In this case, the configuration will be overridden from the input HF config. |
| |
| Finally, the output NeMo model is loaded using the Fabric API obtained from nl.Trainer. |
| """ |
|
|
|
|
def get_args(argv=None):
    """Parse the command-line options for the HF -> NeMo 2.0 import test.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behavior of the original zero-argument call.

    Returns:
        argparse.Namespace with the attributes ``hf_model``, ``model``,
        ``config``, ``output_path`` and ``overwrite``.

    Raises:
        SystemExit: Raised by argparse when ``--hf_model`` or
            ``--output_path`` is missing, or when an unknown option is given.
    """
    parser = argparse.ArgumentParser(description='Test Llama2 7B model conversion from HF')
    # --hf_model and --output_path are consumed unconditionally by the script
    # body (string concatenation and Path construction), so a missing value
    # must be reported as a usage error rather than surfacing later as a
    # TypeError on None.
    parser.add_argument('--hf_model', type=str, required=True, help="Original HF model")
    parser.add_argument("--model", default="LlamaModel", help="Model class from nemo.collections.llm module")
    parser.add_argument("--config", default="Llama2Config7B", help="Config class from nemo.collections.llm module")
    parser.add_argument('--output_path', type=str, required=True, help="NeMo 2.0 export path")
    parser.add_argument('--overwrite', action="store_true", help="Overwrite the output model if exists")
    return parser.parse_args(argv)
|
|
|
|
if __name__ == '__main__':
    args = get_args()

    # Resolve the model and config classes by name from nemo.collections.llm.
    model_cls = getattr(llm, args.model)
    config_cls = getattr(llm, args.config)
    # NOTE(review): the config class object itself (not an instance) is handed
    # to the model constructor here -- confirm this is intended.
    model = model_cls(config=config_cls)

    # Import the HF checkpoint; the value returned by import_ckpt is what gets
    # loaded back further down.
    hf_source = "hf://" + args.hf_model
    export_dir = Path(args.output_path)
    nemo2_path = llm.import_ckpt(
        model=model,
        source=hf_source,
        output_path=export_dir,
        overwrite=args.overwrite,
    )

    # Single-device trainer: Megatron strategy with tensor-model-parallel
    # size 1 and the 'fp16' mixed-precision plugin.
    megatron_strategy = nl.MegatronStrategy(tensor_model_parallel_size=1)
    precision_plugin = nl.MegatronMixedPrecision(precision='fp16')
    trainer = nl.Trainer(devices=1, strategy=megatron_strategy, plugins=precision_plugin)

    # Statement order below is kept exactly as in the original script:
    # to_fabric() first, then setup_environment() on the trainer's strategy,
    # then the load. Presumably load_model needs the environment set up --
    # TODO confirm before reordering.
    fabric = trainer.to_fabric()
    trainer.strategy.setup_environment()
    fabric.load_model(nemo2_path)
|
|