| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """ |
| | Usage example: |
| | python scripts/vlm/gemma3_automodel.py --model llava-hf/llava-v1.6-mistral-7b-hf --data_path naver-clova-ix/cord-v2 |
| | |
| | Used Automodel for training and finetuning HF models. More details can be found at: |
| | https://docs.nvidia.com/nemo-framework/user-guide/latest/automodel/index.html |
| | """ |
| |
|
| |
|
| | import fiddle as fdl |
| | import lightning.pytorch as pl |
| | import torch |
| | from lightning.pytorch.loggers import WandbLogger |
| |
|
| | from nemo import lightning as nl |
| | from nemo.automodel.dist_utils import FirstRankPerNode |
| | from nemo.collections import llm, vlm |
| | from nemo.collections.vlm.hf.data.automodel_datasets import ( |
| | mk_hf_vlm_dataset_cord_v2, |
| | mk_hf_vlm_dataset_fineweb_edu, |
| | mk_hf_vlm_dataset_rdr, |
| | ) |
| |
|
| |
|
| | def make_strategy(strategy, model, devices, num_nodes, adapter_only=False): |
| | if strategy == 'auto': |
| | return pl.strategies.SingleDeviceStrategy( |
| | device='cuda:0', |
| | checkpoint_io=model.make_checkpoint_io(adapter_only=adapter_only), |
| | ) |
| | elif strategy == 'ddp': |
| | return pl.strategies.DDPStrategy( |
| | checkpoint_io=model.make_checkpoint_io(adapter_only=adapter_only), |
| | find_unused_parameters=True, |
| | ) |
| | elif strategy == 'fsdp2': |
| | return nl.FSDP2Strategy( |
| | data_parallel_size=devices * num_nodes, |
| | tensor_parallel_size=1, |
| | checkpoint_io=model.make_checkpoint_io(adapter_only=adapter_only), |
| | ) |
| | else: |
| | raise NotImplementedError("Encountered unknown strategy") |
| |
|
| |
|
| | if __name__ == '__main__': |
| | import argparse |
| |
|
| | parser = argparse.ArgumentParser() |
| | parser.add_argument('--name', default='gemma3_automodel') |
| | parser.add_argument('--model', type=str, default="google/gemma-3-4b-it") |
| | parser.add_argument('--strategy', type=str, default='auto', choices=['auto', 'ddp', 'fsdp2']) |
| | parser.add_argument('--devices', default=1, type=int) |
| | parser.add_argument('--num-nodes', default=1, type=int) |
| | parser.add_argument('--mbs', default=1, type=int) |
| | parser.add_argument('--gbs', default=4, type=int) |
| | parser.add_argument( |
| | "--log_dir", type=str, required=False, default="/results", help="Directory for logging and checkpoints" |
| | ) |
| | parser.add_argument('--accelerator', default='gpu', choices=['gpu']) |
| | parser.add_argument('--max-steps', type=int, default=1000) |
| | parser.add_argument('--wandb-project', type=str, default=None) |
| | parser.add_argument('--disable-ckpt', action='store_false') |
| | parser.add_argument('--use-4bit', help="Load model in 4bit", action="store_true") |
| | parser.add_argument( |
| | "--data_path", |
| | type=str, |
| | default="quintend/rdr-items", |
| | help="Path to the dataset. Can be a local path or a HF dataset name", |
| | ) |
| | parser.add_argument("--peft", type=str, default="none", choices=["lora", "none"], help="Which peft to use") |
| | parser.add_argument("--freeze-vision-model", action="store_true", help="Freeze the vision model parameters") |
| | parser.add_argument("--freeze-language-model", action="store_true", help="Freeze the language model parameters") |
| | args = parser.parse_args() |
| |
|
| | dataset_fn = None |
| | |
| | |
| | if "rdr-items" in args.data_path: |
| | dataset_fn = mk_hf_vlm_dataset_rdr |
| | elif "cord-v2" in args.data_path: |
| | dataset_fn = mk_hf_vlm_dataset_cord_v2 |
| | elif "fineweb-edu" in args.data_path: |
| | dataset_fn = mk_hf_vlm_dataset_fineweb_edu |
| | else: |
| | raise NotImplementedError |
| |
|
| | processor = vlm.HFAutoModelForImageTextToText.configure_processor(args.model) |
| |
|
| | with FirstRankPerNode(): |
| | dataset = dataset_fn(args.data_path, processor, args.mbs, args.gbs) |
| |
|
| | model_kwargs = {} |
| |
|
| | |
| | if "google/gemma-3" in args.model: |
| | model_kwargs["attn_implementation"] = "eager" |
| |
|
| | |
| | |
| | if args.freeze_language_model: |
| | raise ValueError("Freezing language model is not supported for current version of VLM automodel") |
| |
|
| | model = vlm.HFAutoModelForImageTextToText( |
| | args.model, |
| | load_in_4bit=args.use_4bit, |
| | processor=processor, |
| | freeze_language_model=args.freeze_language_model, |
| | freeze_vision_model=args.freeze_vision_model, |
| | **model_kwargs, |
| | ) |
| |
|
| | peft = None |
| | if args.peft == 'lora': |
| | peft = llm.peft.LoRA( |
| | target_modules=['*_proj'], |
| | dim=16, |
| | lora_dtype=torch.bfloat16 if args.use_4bit else None, |
| | ) |
| | nemo_logger = nl.NeMoLogger( |
| | log_dir=args.log_dir, |
| | name=args.name, |
| | wandb=WandbLogger(project=args.wandb_project, name=args.name) if args.wandb_project is not None else None, |
| | ) |
| |
|
| | llm.finetune( |
| | model=model, |
| | data=dataset, |
| | trainer=nl.Trainer( |
| | devices=args.devices, |
| | max_steps=args.max_steps, |
| | accelerator=args.accelerator, |
| | strategy=make_strategy(args.strategy, model, args.devices, args.num_nodes, adapter_only=False), |
| | log_every_n_steps=1, |
| | limit_val_batches=0.0, |
| | num_sanity_val_steps=0, |
| | accumulate_grad_batches=max(1, args.gbs // args.mbs), |
| | gradient_clip_val=1, |
| | use_distributed_sampler=False, |
| | enable_checkpointing=args.disable_ckpt, |
| | precision='bf16-mixed', |
| | num_nodes=args.num_nodes, |
| | ), |
| | |
| | |
| | optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)), |
| | log=nemo_logger, |
| | peft=peft, |
| | ) |
| |
|