{
"cells": [
{
"cell_type": "markdown",
"id": "b29a4b72-31bb-4268-9598-2cd2b6f7475e",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"source": [
"# NeVA Training / Inference Tutorial\n",
"\n",
"### Note:\n",
"Currently, this notebook must be run in a NeMo container. An example command to launch the container:\n",
"\n",
"```\n",
"docker run --gpus all -it --rm -v <your_nemo_dir>:/opt/NeMo --shm-size=8g \\\n",
" -p 8888:8888 --ulimit memlock=-1 --ulimit \\\n",
" stack=67108864 <your_nemo_container>\n",
"```\n",
"\n",
"## Introduction\n",
"\n",
"This notebook illustrates how to train and perform inference using NeVA with the NeMo Toolkit. NeVA originates from [LLaVA](https://github.com/haotian-liu/LLaVA) (Large Language and Vision Assistant) and is a powerful multimodal image-text instruction tuned model optimized within the NeMo Framework. \n",
"\n",
"This tutorial will guide you through the following topics:\n",
"1. Prepare pre-requisites for NeVA training\n",
"2. Training a NeVA model\n",
"3. Performing inference with the trained model\n",
"\n",
"## Datasets\n",
"\n",
"### Pre-Training Dataset\n",
"\n",
"The pre-training dataset is open-sourced from the LLaVA implementation and can be downloaded [here](https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain). The dataset consists of a 558K subset of the LAION-CC-SBU dataset with BLIP captions.\n",
"\n",
"The associated images for pretraining can be downloaded via HuggingFace [here](https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/blob/main/images.zip).\n",
"\n",
"### Instruction Tuning Dataset\n",
"\n",
"The instruction tuning annotations are sourced from the LLaVA implementation and are available [here](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_v1_5_mix665k.json).\n",
"\n",
"The associated images for the mixture instruction tuning annotations can be found [here](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#visual-instruction-tuning). After extracting, the data should be formatted as follows:\n",
"\n",
"```\n",
" images\n",
" βββ coco\n",
" β βββ train2017\n",
" βββ gqa\n",
" β βββ images\n",
" βββ ocr_vqa\n",
" β βββ images\n",
" βββ textvqa\n",
" β βββ train_images\n",
" βββ vg\n",
" βββ VG_100K\n",
" βββ VG_100K_2\n",
"```\n",
"\n",
"After downloading all below datasets for pretraining and instruction tuning, please put data folder at `/workspace/datasets`. Your dataset directory should look something similar to:\n",
"\n",
"```\n",
"LLaVA-Pretrain-LCS-558K\n",
"βββ blip_laion_cc_sbu_558k.json\n",
"βββ images\n",
"LLaVA-Instruct-mixture\n",
"βββ llava_v1_5_mix665k.json\n",
"βββ images\n",
" βββ ...\n",
"```\n",
"\n",
"## Setting up Checkpoint and Tokenizer\n",
"\n",
"In this notebook, we first need to convert the Vicuna 1.5 checkpoint into the .nemo format. Meanwhile, special tokens must be incorporated into the tokenizer for NeVA training. After downloading language models from Hugging Face, ensure you also fetch the corresponding tokenizer model. Using the 7B-chat model as a reference."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6d80adff-bd3a-40e0-9441-684328ec7596",
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"! mkdir -p /workspace/checkpoints\n",
"\n",
"# Download vicuna checkpoint from HF\n",
"! git clone https://huggingface.co/lmsys/vicuna-7b-v1.5 /workspace/checkpoints/vicuna-7b-v1.5\n",
"\n",
"# Convert checkpoint\n",
"! python /opt/NeMo/scripts/checkpoint_converters/convert_llama_hf_to_nemo.py \\\n",
" --input_name_or_path /workspace/checkpoints/vicuna-7b-v1.5 \\\n",
" --output_path /workspace/checkpoints/vicuna-7b-v1.5.nemo\n",
"\n",
"# Prepare tokenizer\n",
"! cd /opt && git clone https://github.com/google/sentencepiece.git && \\\n",
" cd sentencepiece && \\\n",
" mkdir build && \\\n",
" cd build && \\\n",
" cmake .. && \\\n",
" make && \\\n",
" make install && \\\n",
" ldconfig && \\\n",
"cd /opt/sentencepiece/src/ && protoc --python_out=/opt/NeMo/scripts/tokenizers/ sentencepiece_model.proto && \\\n",
"export PYTHONPATH=$PYTHONPATH:/opt/NeMo/scripts/tokenizers\n",
"\n",
"! python /opt/NeMo/scripts/tokenizers/add_special_tokens_to_sentencepiece.py \\\n",
"--input_file /workspace/checkpoints/vicuna-7b-v1.5/tokenizer.model \\\n",
"--output_file /workspace/checkpoints/vicuna-7b-v1.5/tokenizer_neva.model \\\n",
"--is_userdefined \\\n",
"--tokens \"<extra_id_0>\" \"<extra_id_1>\" \"<extra_id_2>\" \"<extra_id_3>\" \\\n",
" \"<extra_id_4>\" \"<extra_id_5>\" \"<extra_id_6>\" \"<extra_id_7>\"\n"
]
},
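{
"cell_type": "markdown",
"id": "7c1f3a2b",
"metadata": {},
"source": [
"A quick sanity check of the conversion and tokenizer steps above (a minimal sketch; the paths follow the commands above, and the `sentencepiece` Python package is assumed to be available in the container):\n",
"\n",
"```\n",
"ls -lh /workspace/checkpoints/vicuna-7b-v1.5.nemo\n",
"python -c \"import sentencepiece as spm; sp = spm.SentencePieceProcessor(model_file='/workspace/checkpoints/vicuna-7b-v1.5/tokenizer_neva.model'); print(sp.vocab_size())\"\n",
"```\n",
"\n",
"The reported vocabulary size should be larger than the original Vicuna tokenizer's by the number of `<extra_id_*>` tokens added above."
]
},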
{
"cell_type": "markdown",
"id": "6b619e0a",
"metadata": {},
"source": [
"## Training\n",
"\n",
"### Feature Alignment Pre-Training\n",
"\n",
"We provide a set of scripts for pre-training and fine-tuning which can be kicked off with CLI flags defining specified arguments. \n",
"\n",
"An example of a pre-training script execution (note the scripts will only perform 100 steps with a small micro batch size, this is not a full training):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3930351e",
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"! torchrun --nproc_per_node=4 /opt/NeMo/examples/multimodal/multimodal_llm/neva/neva_pretrain.py \\\n",
" ++cluster_type=BCP \\\n",
" trainer.precision=bf16 \\\n",
" trainer.num_nodes=1 \\\n",
" trainer.devices=4 \\\n",
" trainer.val_check_interval=50 \\\n",
" trainer.limit_val_batches=5 \\\n",
" trainer.log_every_n_steps=1 \\\n",
" trainer.max_steps=100 \\\n",
" model.megatron_amp_O2=True \\\n",
" model.micro_batch_size=1 \\\n",
" model.global_batch_size=4 \\\n",
" model.tensor_model_parallel_size=1 \\\n",
" model.pipeline_model_parallel_size=1 \\\n",
" model.mcore_gpt=True \\\n",
" model.transformer_engine=True \\\n",
" model.data.data_path=/workspace/datasets/LLaVA-Pretrain-LCS-558K/blip_laion_cc_sbu_558k.json \\\n",
" model.data.image_folder=/workspace/datasets/LLaVA-Pretrain-LCS-558K/images \\\n",
" model.tokenizer.library=sentencepiece \\\n",
" model.tokenizer.model=/workspace/checkpoints/vicuna-7b-v1.5/tokenizer_neva.model \\\n",
" model.encoder_seq_length=4096 \\\n",
" model.num_layers=32 \\\n",
" model.hidden_size=4096 \\\n",
" model.ffn_hidden_size=11008 \\\n",
" model.num_attention_heads=32 \\\n",
" model.normalization=rmsnorm \\\n",
" model.do_layer_norm_weight_decay=False \\\n",
" model.apply_query_key_layer_scaling=True \\\n",
" model.bias=False \\\n",
" model.activation=fast-swiglu \\\n",
" model.headscale=False \\\n",
" model.position_embedding_type=rope \\\n",
" model.rotary_percentage=1.0 \\\n",
" model.num_query_groups=null \\\n",
" model.data.num_workers=0 \\\n",
" model.mm_cfg.llm.from_pretrained=/workspace/checkpoints/vicuna-7b-v1.5.nemo \\\n",
" model.mm_cfg.llm.model_type=v1 \\\n",
" model.data.conv_template=v1 \\\n",
" model.mm_cfg.vision_encoder.from_pretrained='openai/clip-vit-large-patch14' \\\n",
" model.mm_cfg.vision_encoder.from_hf=True \\\n",
" model.optim.name=\"fused_adam\" \\\n",
" exp_manager.create_checkpoint_callback=True \\\n",
" exp_manager.checkpoint_callback_params.save_nemo_on_train_end=True \\\n",
" exp_manager.create_wandb_logger=False"
]
},
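{
"cell_type": "markdown",
"id": "9d4e5f6a",
"metadata": {},
"source": [
"After pre-training completes, the experiment manager writes a `.nemo` checkpoint under `/workspace/nemo_experiments/`; the fine-tuning step below assumes it ends up at `/workspace/nemo_experiments/nemo_neva/checkpoints/nemo_neva.nemo`. A quick way to locate it (a sketch):\n",
"\n",
"```\n",
"find /workspace/nemo_experiments -name '*.nemo'\n",
"```"
]
},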
{
"cell_type": "markdown",
"id": "f24ee70d-3025-47f6-8571-295b024c3e05",
"metadata": {},
"source": [
"**Note**: To initialize training a model from scratch rather than from a pretrained checkpoint, you may specify `null` instead of a path in the CLI arguments.\n",
"\n",
"### Image-Language Pair Instruction Fine-Tuning\n",
"\n",
"Fine-tuning can also be run from within the container via a similar command leveraging the `neva_finetune.py` script. We leverage the checkpoint saved from pretrain step to further finetune it, given by `model.restore_from_path=/workspace/nemo_experiments/nemo_neva/checkpoints/nemo_neva.nemo`.\n",
"\n",
"An example of an image-text pair instruction tuning script execution (note the scripts will only perform 1000 steps with a small micro batch size, this is not a full training):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "97963224",
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"! torchrun --nproc_per_node=4 /opt/NeMo/examples/multimodal/multimodal_llm/neva/neva_finetune.py \\\n",
"++cluster_type=BCP \\\n",
" trainer.precision=bf16 \\\n",
" trainer.num_nodes=1 \\\n",
" trainer.devices=4 \\\n",
" trainer.val_check_interval=50 \\\n",
" trainer.limit_val_batches=50 \\\n",
" trainer.max_steps=100 \\\n",
" model.restore_from_path=/workspace/nemo_experiments/nemo_neva/checkpoints/nemo_neva.nemo \\\n",
" model.megatron_amp_O2=True \\\n",
" model.micro_batch_size=1 \\\n",
" model.global_batch_size=2 \\\n",
" model.tensor_model_parallel_size=4 \\\n",
" model.pipeline_model_parallel_size=1 \\\n",
" model.mcore_gpt=True \\\n",
" model.transformer_engine=True \\\n",
" model.data.data_path=/workspace/datasets/LLaVA-Instruct-mixture/llava_v1_5_mix665k.json \\\n",
" model.data.image_folder=/workspace/datasets/LLaVA-Instruct-mixture/images \\\n",
" model.tokenizer.library=sentencepiece \\\n",
" model.tokenizer.model=/workspace/checkpoints/vicuna-7b-v1.5/tokenizer_neva.model \\\n",
" model.encoder_seq_length=4096 \\\n",
" model.num_layers=32 \\\n",
" model.hidden_size=4096 \\\n",
" model.ffn_hidden_size=11008 \\\n",
" model.num_attention_heads=32 \\\n",
" model.normalization=rmsnorm \\\n",
" model.do_layer_norm_weight_decay=False \\\n",
" model.apply_query_key_layer_scaling=True \\\n",
" model.bias=False \\\n",
" model.activation=fast-swiglu \\\n",
" model.headscale=False \\\n",
" model.position_embedding_type=rope \\\n",
" model.rotary_percentage=1.0 \\\n",
" model.num_query_groups=null \\\n",
" model.data.num_workers=0 \\\n",
" model.mm_cfg.llm.from_pretrained=/workspace/checkpoints/vicuna-7b-v1.5.nemo \\\n",
" model.mm_cfg.llm.model_type=v1 \\\n",
" model.data.conv_template=v1 \\\n",
" model.mm_cfg.vision_encoder.from_pretrained='openai/clip-vit-large-patch14' \\\n",
" model.mm_cfg.vision_encoder.from_hf=True \\\n",
" exp_manager.create_checkpoint_callback=True \\\n",
" exp_manager.checkpoint_callback_params.save_nemo_on_train_end=True \\\n",
" exp_manager.name=\"nemo_neva_finetune\" \\\n",
" model.optim.name=\"fused_adam\""
]
},
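{
"cell_type": "markdown",
"id": "3e2f1a0b",
"metadata": {},
"source": [
"With `exp_manager.name` set to `nemo_neva_finetune` above, the fine-tuned `.nemo` checkpoint should similarly be written under `/workspace/nemo_experiments/nemo_neva_finetune/` (a sketch to confirm, assuming the default experiment manager layout):\n",
"\n",
"```\n",
"find /workspace/nemo_experiments/nemo_neva_finetune -name '*.nemo'\n",
"```"
]
},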
{
"cell_type": "markdown",
"id": "d69e937c",
"metadata": {},
"source": [
"## Inference\n",
"\n",
"### From Pre-trained Checkpoints\n",
"\n",
"If you would like to use NeVA for inference from pre-trained checkpoint via HuggingFace, you can use the checkpoint from fine-tune step or convert from HuggingFace to `.nemo` first. Since we didn't finish full training in this tutorial with NeMo. We will instruct how you can convert a checkpoint from Hugging Face."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f398c26",
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"! python3 /opt/NeMo/scripts/checkpoint_converters/convert_llava_hf_to_nemo.py \\\n",
" --input_name_or_path llava-hf/llava-1.5-7b-hf \\\n",
" --output_path /workspace/checkpoints/llava-7b.nemo \\\n",
" --tokenizer_path /workspace/checkpoints/vicuna-7b-v1.5/tokenizer_neva.model"
]
},
{
"cell_type": "markdown",
"id": "5235639a",
"metadata": {},
"source": [
"### Running Inference\n",
"\n",
"NeVA inference via the NeMo Framework can be quickly spun up via the NeMo Launcher and a few modifications to use the default NeVA inference config file.\n",
"\n",
"Inference can be run with a similar command leveraging the provided inference script `neva_evaluation.py` within the container.\n",
"\n",
"An example of an inference script execution:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ee0156ea",
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"! echo '{\"image\": \"RTX4080.png\", \"prompt\": \"<image>\\nCan you describe this image?\"}' > sample.jsonl\n",
"! mkdir images && wget https://assets.nvidia.partners/images/png/TUF_Gaming_GeForce_RTX_4080_SUPER_OC_edition_packaging_with_card__12419.png --output-document=images/RTX4080.png\n",
"! torchrun --nproc_per_node=1 /opt/NeMo/examples/multimodal/multimodal_llm/neva/neva_evaluation.py \\\n",
"tensor_model_parallel_size=1 \\\n",
"pipeline_model_parallel_size=1 \\\n",
"neva_model_file=/workspace/checkpoints/llava-7b.nemo \\\n",
"trainer.devices=1 \\\n",
"trainer.precision=bf16 \\\n",
"prompt_file=sample.jsonl \\\n",
"inference.media_base_path=images \\\n",
"output_file=output.jsonl \\\n",
"inference.temperature=0.2 \\\n",
"inference.tokens_to_generate=256"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}