{ "cells": [ { "cell_type": "markdown", "id": "ec4d51a3-4b0b-40d0-b84d-34e23a523468", "metadata": {}, "source": [ "# Train Your Own Reasoning Model in 48 Hours on a Single GPU\n", "\n", "This tutorial provides a tried-and-true recipe for training your own reasoning model by fine-tuning a [Meta LLaMA 3.1–8B Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model with NVIDIA NeMo in about 48 hours on a single H100 80GB GPU.\n", "\n", "This recipe is inspired by the [Llama Nemotron family of models](https://www.nvidia.com/en-us/ai-data-science/foundation-models/llama-nemotron/), where the model can selectively turn reasoning on or off based on instructions in the system prompt. You'll train your model on complex instruction-following and reasoning tasks using the [Llama-Nemotron-Post-Training-Dataset](https://huggingface.co/datasets/nvidia/Llama-Nemotron-Post-Training-Dataset).\n", "\n", "### ✅ What You'll Learn\n", "1. An effective recipe to train your own reasoning model, similar to [Llama Nemotron reasoning models](https://www.nvidia.com/en-us/ai-data-science/foundation-models/llama-nemotron/).\n", "2. How to apply fine-tuning with NeMo 2.0, using LoRA adapters or full model fine-tuning.\n", "3. How to train using NeMo's distributed, mixed-precision trainer.\n", "4. How to save a fine-tuned checkpoint ready for evaluation or deployment.\n", "\n", "\n", "### 🧰 Tools and Resources\n", "* [NeMo Framework](https://docs.nvidia.com/nemo-framework/user-guide/latest/overview.html)\n", "* [Llama-Nemotron-Post-Training-Dataset](https://huggingface.co/datasets/nvidia/Llama-Nemotron-Post-Training-Dataset), an open source dataset for instilling reasoning behavior in large language models.\n", "* [NeMo Curator](https://github.com/NVIDIA/NeMo-Curator) for data curation\n", "\n", "## 📌 Requirements\n", "\n", "### Prerequisites\n", "\n", "* Access to the latest NeMo Framework NGC containers\n", "* This playbook has been tested on `nvcr.io/nvidia/nemo:25.04.01`. 
It is expected to work similarly in other environments.\n", "* A valid Hugging Face API token with access to the [Meta LLaMA 3.1-8B Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model (since this is a gated model).\n", "\n", "### Dataset\n", "To follow along, you need an appropriate reasoning dataset. Check out the tutorial on [curating the Llama Nemotron Reasoning Dataset with NVIDIA NeMo Curator](https://github.com/NVIDIA/NeMo-Curator/tree/main/tutorials/llama-nemotron-data-curation).\n", "The output from that tutorial is the training set input to this playbook.\n", "\n", "### Hardware Requirements\n", "\n", "You can either fine-tune a LoRA adapter, or fine-tune the entire model.\n", "\n", "* **Parameter Efficient Fine-Tuning via LoRA adapters**:\n", "  * This playbook has been verified on **a single H100 80GB** when PEFT is enabled.\n", "  * It takes **48 hours** to train the model for 2000 steps and observe reasoning behavior.\n", "* **Full Fine-Tuning**:\n", "  * This playbook has been tested on **8xH100 80GB** GPUs. You can scale training to more GPUs as well as multiple nodes by modifying the appropriate parameters.\n", "  * It takes **12 hours** to train the model for 2000 steps and observe reasoning behavior.\n", "\n", "\n", "Let's dive in!" ] }, { "cell_type": "markdown", "id": "e4929586", "metadata": {}, "source": [ "## 🚀 Step 0. Launch the NeMo Framework container\n", "\n", "Run the following command to launch the NeMo Framework training container. 
Be sure to populate the `HF_TOKEN` variable with a valid API key:\n", "\n", "```bash\n", "docker run -it --rm \\\n", "    --gpus all --shm-size=16GB \\\n", "    --ipc=host --network host \\\n", "    -v $(pwd):/workspace \\\n", "    -e HF_TOKEN= \\\n", "    nvcr.io/nvidia/nemo:25.04.01\n", "```\n", "\n", "#### Launch Jupyter Notebook as follows:\n", "\n", "```bash\n", "jupyter notebook --allow-root --ip 0.0.0.0 --port 8088 --no-browser --NotebookApp.token=''\n", "```" ] }, { "cell_type": "markdown", "id": "0dd940e9-0c80-4021-966b-3917d6650fe7", "metadata": {}, "source": [ "## 📋 Step 1. Preparations\n", "Specify the parameters you will use for model training. These include:\n", "\n", "1. Whether you will do PEFT/LoRA or full fine-tuning.\n", "2. How many GPUs you will use for training.\n", "3. How many steps you will train for.\n", "   - ⚠️ **At least 2000 steps are needed to observe reasoning behavior.**\n", "4. The location where the training dataset is stored. This dataset is expected to be in the structured JSONL format, stored locally with the name `training.jsonl`.\n", "5. Where to store the Hugging Face model checkpoint that is converted to NeMo format.\n", "6. The location where the training outputs will be stored.\n", "\n", "Set up the values in the cell below, then execute it." ] }, { "cell_type": "code", "execution_count": null, "id": "de9e0422", "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "# LoRA or full model training? 
Set to `False` for full model training.\n", "DO_LORA_ADAPTER_TRAINING = True\n", "# The number of GPUs used for training\n", "N_DEVICES = 1\n", "# The number of training steps (at least 2000 steps are required for a good reasoning model)\n", "N_STEPS = 2000\n", "# Where the training data is stored\n", "TRAINING_DATASET_DIR = \"/path/to/your/data_directory\"\n", "# Where the NeMo conversions are stored\n", "NEMO_CONVERSION_DIR = \"/path/to/your/nemo_conversions\"\n", "# Where the checkpoints and other experiment results are stored\n", "OUTPUT_DIR = \"/path/to/your/output_directory\"\n", "# Where the checkpoint converted for deployment is stored\n", "OUTPUT_DEPLOYMENT_DIR = \"/path/to/your/output_deployment\"\n", "\n", "#\n", "# Some basic sanity checks\n", "#\n", "assert os.path.exists(TRAINING_DATASET_DIR), \"Training dataset directory does not exist. Please set TRAINING_DATASET_DIR to a valid path.\"\n", "# Ensure training.jsonl exists in the data directory\n", "assert os.path.exists(os.path.join(TRAINING_DATASET_DIR, \"training.jsonl\")), \"training.jsonl file does not exist in the specified TRAINING_DATASET_DIR.\"\n", "# Ensure the NeMo conversion directory exists\n", "assert os.path.isdir(NEMO_CONVERSION_DIR), \"NeMo conversion directory does not exist. Please set NEMO_CONVERSION_DIR to a valid path.\"\n", "# Ensure the output directory exists\n", "assert os.path.isdir(OUTPUT_DIR), \"Output directory does not exist. Please set OUTPUT_DIR to a valid path.\"" ] }, { "cell_type": "markdown", "id": "3914f817", "metadata": {}, "source": [ "\n", "\n", "### Convert Hugging Face Checkpoint to NeMo Format\n", "\n", "Before training, we need to convert the Hugging Face LLaMA 3.1–8B Instruct checkpoint into NeMo format. NeMo provides a built-in utility, `llm.import_ckpt()`, to handle this conversion.\n", "After conversion, the model can be loaded and fine-tuned using NeMo APIs directly.\n", "\n", "⚠️ This step only needs to be run once per model." 
] }, { "cell_type": "code", "execution_count": null, "id": "013764ee-95ed-44fa-85b9-4296d37e7c24", "metadata": {}, "outputs": [], "source": [ "import nemo_run as run\n", "from nemo import lightning as nl\n", "from nemo.collections import llm\n", "\n", "import pytorch_lightning as pl\n", "from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed\n", "from datetime import datetime\n", "\n", "# Configure the import from HuggingFace format to NeMo format\n", "def configure_checkpoint_conversion():\n", "    return run.Partial(\n", "        llm.import_ckpt,\n", "        model=llm.llama31_8b.model(), # Predefined LLaMA 3.1 8B model structure\n", "        source=\"hf://meta-llama/Llama-3.1-8B-Instruct\", # Path to HF checkpoint (local or HF hub)\n", "        output_path=NEMO_CONVERSION_DIR, # Directory to save the converted NeMo checkpoint\n", "        overwrite=True, # Setting this to False will result in an error if the checkpoint was already converted\n", "    )\n", "\n", "# Create the configured import task\n", "import_ckpt = configure_checkpoint_conversion()\n", "\n", "# Define the local executor (single-node)\n", "local_executor = run.LocalExecutor()\n", "\n", "# Execute the checkpoint conversion\n", "run.run(import_ckpt, executor=local_executor)" ] }, { "cell_type": "markdown", "id": "a148354c-af20-44f0-bbda-bf1362fd2c24", "metadata": {}, "source": [ "✓ Checkpoint imported to /root/.cache/nemo/models/Meta-Llama-3.1-8B" ] }, { "cell_type": "markdown", "id": "b1492948-5608-4596-b571-f141ec6bde9e", "metadata": {}, "source": [ "## 📂 Step 2. Prepare Data\n", "\n", "Next, define the configuration for loading and preprocessing an instruction-tuning dataset using NeMo's `FineTuningDataModule`. The dataset is expected to be in a structured format (e.g. 
JSONL), stored locally as `training.jsonl`.\n", "\n", "The training-related parameters like batch size, number of workers, memory mapping, and device count can be modified based on the size of the model, dataset size and compute resources available. We recommend a context length of at least 8192 for best results.\n", "\n", "⚠️ For this particular dataset, the global batch size must be at least 256 for best results. You can increase this value based on compute resources available to you.\n", "\n", "⚠️ If you observe any out-of-memory issues with CPU memory (i.e. RAM), you can decrease `num_workers`.\n", "\n", "#### A Note on Batching and Gradient Accumulation\n", "The parameter `micro_batch_size` specifies how many samples **each** GPU processes in a single forward/backward pass. The `global_batch_size` parameter specifies the total batch size across **all** devices, **including gradient accumulation**.\n", "\n", "NeMo computes the required number of gradient accumulation steps internally as `global_batch_size / (micro_batch_size × number of data-parallel GPUs)`. In the example below, a single GPU with `micro_batch_size=1` and `global_batch_size=256` results in `256` gradient accumulation steps." ] }, { "cell_type": "code", "execution_count": null, "id": "ca01f0bc-fdaa-4023-97cb-c5a26f0840a8", "metadata": {}, "outputs": [], "source": [ "timestamp = datetime.now().strftime(\"%Y%m%d-%H%M\")\n", "experiment_name = \"baseline-lora-only\" if DO_LORA_ADAPTER_TRAINING else \"baseline-full-model\"\n", "\n", "# Define fine-tuning dataset configuration\n", "finetune_config = run.Config(\n", "    llm.FineTuningDataModule,\n", "    dataset_root=TRAINING_DATASET_DIR, # Path to your preprocessed dataset (JSONL, etc.)\n", "    seq_length=8192, # Max sequence length for input tokens\n", "    micro_batch_size=1, # Per-device batch size. 
Each GPU will process this many samples at a time\n", "    global_batch_size=256, # Total batch size across all devices, including gradient accumulation\n", "    seed=1234, # Seed for reproducibility\n", "    memmap_workers=1, # Worker processes used to load the memory-mapped dataset\n", "    num_workers=8, # DataLoader worker processes\n", "    pin_memory=True, # Optimize data transfer to GPU\n", ")" ] }, { "cell_type": "markdown", "id": "aec4d9d6-0b90-4562-a9b8-7ac947ba851f", "metadata": {}, "source": [ "## 🛠️ Step 3. Configure Fine Tuning with the NeMo 2.0 API\n", "\n", "In this step, we'll use the modular NeMo 2.0 API to configure:\n", "\n", "* The distributed trainer\n", "\n", "* Logging and checkpointing\n", "\n", "* Optimizer with cosine annealing scheduler\n", "\n", "* Model definition and resume behavior\n", "\n", "* Final recipe assembly for fine-tuning" ] }, { "cell_type": "markdown", "id": "214912f1-2cd3-4db2-a295-9efb8612a043", "metadata": {}, "source": [ "### ⚙️ 3.1 Configure the Trainer\n", "\n", "We define the training strategy using NeMo's `MegatronStrategy` for distributed training, with a tensor model parallel size of 1 (i.e. no tensor parallelism) and bf16 mixed precision.\n", "\n", "For this exercise, we do not have a validation set, so all the settings for validation are set to 0. If you decide to use a validation set, you will need to update these parameters." 
] }, { "cell_type": "code", "execution_count": null, "id": "0abe9b85-a9b6-418a-b013-ed5459a1e86b", "metadata": {}, "outputs": [], "source": [ "def trainer() -> run.Config[nl.Trainer]:\n", "    strategy = run.Config(\n", "        nl.MegatronStrategy,\n", "        tensor_model_parallel_size=1,\n", "        optimizer_cpu_offload=True\n", "    )\n", "    trainer = run.Config(\n", "        nl.Trainer,\n", "        devices=N_DEVICES,\n", "        num_nodes=1, # Change to >1 for multi-node training\n", "        max_steps=N_STEPS,\n", "        accelerator=\"gpu\",\n", "        strategy=strategy,\n", "        plugins=bf16_mixed(),\n", "        log_every_n_steps=10,\n", "        limit_val_batches=0,\n", "        val_check_interval=0,\n", "        num_sanity_val_steps=0,\n", "        use_distributed_sampler=False,\n", "    )\n", "    return trainer" ] }, { "cell_type": "markdown", "id": "2e88dd49-314f-42fb-a4fb-30ad3d4b490d", "metadata": {}, "source": [ "### 📝 3.2 Configure Logging and Checkpointing\n", "This cell defines the logging mechanism and model checkpointing during training." ] }, { "cell_type": "code", "execution_count": null, "id": "802508e1-8051-431f-bd54-44ffe4d292d0", "metadata": {}, "outputs": [], "source": [ "def logger() -> run.Config[nl.NeMoLogger]:\n", "    ckpt = run.Config(\n", "        nl.ModelCheckpoint,\n", "        save_last=True,\n", "        every_n_train_steps=100, # Checkpoint frequency\n", "        monitor=\"reduced_train_loss\",\n", "        save_top_k=2, # How many top checkpoints to keep\n", "        save_on_train_epoch_end=True,\n", "        save_optim_on_train_end=True,\n", "    )\n", "\n", "    return run.Config(\n", "        nl.NeMoLogger,\n", "        name=f\"{experiment_name}-trained-model-checkpoints\",\n", "        log_dir=f\"{OUTPUT_DIR}/results-{timestamp}-{N_DEVICES}-devices-{experiment_name}\",\n", "        use_datetime_version=True,\n", "        ckpt=ckpt,\n", "        wandb=None\n", "    )" ] }, { "cell_type": "markdown", "id": "4489a8ab-be37-406b-a11e-b6a597604614", "metadata": {}, "source": [ "### 📈 3.3 Configure Optimizer with Cosine Annealing\n", "We use the Adam optimizer with gradient clipping, distributed optimizer support, and a 
cosine annealing learning rate schedule.\n", "\n", "We recommend starting off with a high learning rate (`1e-4`) and gradually decreasing it for best results. For this recipe, the warmup phase covers the first 5% of the training steps (100 of the 2000 steps). We recommend a warmup of between 0% and 10% of the training steps for most use cases." ] }, { "cell_type": "code", "execution_count": null, "id": "8badf059-c275-436c-ab5e-5db2d6c32839", "metadata": {}, "outputs": [], "source": [ "from nemo.lightning.pytorch.optim import CosineAnnealingScheduler\n", "from megatron.core.optimizer import OptimizerConfig\n", "\n", "def lr_scheduler():\n", "    warmup_steps = int(N_STEPS * 0.05) # 5% of total steps for warmup\n", "    return run.Config(\n", "        CosineAnnealingScheduler,\n", "        warmup_steps=warmup_steps,\n", "        constant_steps=0, # No constant phase: keep decaying the learning rate\n", "        min_lr=1e-5,\n", "    )\n", "\n", "def adam_with_cosine_annealing() -> run.Config[nl.OptimizerModule]:\n", "    opt_cfg = run.Config(\n", "        OptimizerConfig,\n", "        optimizer=\"adam\",\n", "        lr=1e-4,\n", "        weight_decay=0.001,\n", "        use_distributed_optimizer=True,\n", "        clip_grad=1.0,\n", "        bf16=True,\n", "    )\n", "\n", "    return run.Config(\n", "        nl.MegatronOptimizerModule,\n", "        config=opt_cfg,\n", "        lr_scheduler=lr_scheduler(),\n", "    )" ] }, { "cell_type": "markdown", "id": "f17225cf-6fa1-48ef-95e8-54b57efcf967", "metadata": {}, "source": [ "### 🧠 3.4 Define the Base Model and Resume Logic\n", "We use the built-in LLaMA 3.1 8B config from NeMo and optionally resume from a previously saved checkpoint." 
] }, { "cell_type": "code", "execution_count": null, "id": "5a003f8c-4947-49b8-bd24-8712fcf87532", "metadata": {}, "outputs": [], "source": [ "def llama31_8b() -> run.Config[pl.LightningModule]:\n", "    return run.Config(llm.LlamaModel, config=run.Config(llm.Llama31Config8B))\n", "\n", "def resume() -> run.Config[nl.AutoResume]:\n", "    return run.Config(\n", "        nl.AutoResume,\n", "        restore_config=run.Config(\n", "            nl.RestoreConfig,\n", "            path=NEMO_CONVERSION_DIR,\n", "        ),\n", "        resume_if_exists=True,\n", "    )" ] }, { "cell_type": "markdown", "id": "7bfe4c09-ea6d-4dd0-9953-cbdd797225ac", "metadata": {}, "source": [ "### 📦 3.5 Assemble the Fine-Tuning Recipe\n", "This ties together the model, trainer, dataset config, optimizer, and logger into a single training recipe using NeMo's `run.Partial` system.\n", "\n", "#### Recommended LoRA settings\n", "We recommend using a LoRA rank of at least 64 with `alpha = 128` for good results, as lower ranks did not perform as well in our experiments.\n", "\n", "#### Preventing Out-of-Memory Errors for Full Fine-Tuning\n", "Fine-tuning an entire 8-billion-parameter model requires a lot of GPU memory, and even 8xH100 80GB might not be enough, depending on the model's context size. One way to save memory is to enable [activation recomputation](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/optimizations/activation_recomputation.html). This approach significantly reduces the memory requirements at the cost of slower training speeds.\n", "\n", "In the cell below, we use the most aggressive form of activation recomputation to ensure this notebook can be executed on a single node with 8x80GB GPUs. For more information on available options, please refer to the [documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/optimizations/activation_recomputation.html)." 
] }, { "cell_type": "code", "execution_count": null, "id": "8c175cd8-c4fd-4f24-a206-435a79965131", "metadata": {}, "outputs": [], "source": [ "def configure_finetuning_recipe():\n", "    peft = None\n", "\n", "    # Are we training a LoRA adapter or the full model?\n", "    if DO_LORA_ADAPTER_TRAINING:\n", "        peft = llm.peft.LoRA(\n", "            dim=64,\n", "            alpha=128,\n", "        )\n", "\n", "    recipe = run.Partial(\n", "        llm.finetune,\n", "        model=llama31_8b(),\n", "        trainer=trainer(),\n", "        data=finetune_config, # From earlier step\n", "        log=logger(),\n", "        optim=adam_with_cosine_annealing(),\n", "        resume=resume(),\n", "        peft=peft, # LoRA adapter configuration if applicable\n", "    )\n", "\n", "    if not DO_LORA_ADAPTER_TRAINING:\n", "        # Use aggressive recomputation settings for full model training to save on GPU memory\n", "        recipe.model.config.recompute_granularity = \"full\"\n", "        recipe.model.config.recompute_method = \"uniform\"\n", "        recipe.model.config.recompute_num_layers = 1\n", "\n", "    return recipe" ] }, { "cell_type": "markdown", "id": "1cf5592c-abfb-4489-b8a2-67b4164346b5", "metadata": {}, "source": [ "## ▶️ Step 4: Run Fine-Tuning with NeMo 2.0 and nemo-run\n", "Now that everything is configured (model, trainer, optimizer, logging, and data), it's time to launch the training job using nemo-run's `LocalExecutor`.\n", "\n", "This will:\n", "\n", "* Use torchrun to launch a multi-GPU job\n", "* Set environment variables for optimized NCCL behavior\n", "* Kick off the training loop with your full configuration\n", "\n", "While the training is ongoing, you can monitor the progress via TensorBoard." 
] }, { "cell_type": "code", "execution_count": null, "id": "c837f7c7-ad01-4bee-87ba-2e52f919a38b", "metadata": {}, "outputs": [], "source": [ "def local_executor_torchrun() -> run.LocalExecutor:\n", "    # Environment variables to optimize distributed training\n", "    env_vars = {\n", "        \"TORCH_NCCL_AVOID_RECORD_STREAMS\": \"1\",\n", "        \"NCCL_NVLS_ENABLE\": \"0\",\n", "    }\n", "\n", "    return run.LocalExecutor(\n", "        ntasks_per_node=N_DEVICES,\n", "        launcher=\"torchrun\",\n", "        env_vars=env_vars,\n", "    )\n", "\n", "# Execute the training run\n", "if __name__ == '__main__':\n", "    recipe = configure_finetuning_recipe()\n", "    print(recipe)\n", "    run.run(\n", "        recipe,\n", "        executor=local_executor_torchrun()\n", "    )" ] }, { "cell_type": "markdown", "id": "bfff9c71", "metadata": {}, "source": [ "For your reference, here are the loss plots from our own experiments using 500,000 training samples, with a batch size of 256 and 2000 training steps.\n", "\n", "You might be wondering about the sudden loss drop at the end. This is expected!\n", "The training dataset is arranged in increasing order of sample difficulty (i.e. curriculum learning).\n", "With a batch size of 256 and 2000 steps, the model processes 256 × 2000 = 512,000 samples, which is just slightly over 1 epoch of the 500,000 training samples.\n", "Towards the end of training, when the model sees the first few (easier) samples again, it can easily predict the right tokens for them, so the loss ends up being much lower.\n", "\n", "#### LoRA Training Loss Plots\n", "![LoRA Training Loss Plots](images/loss-plot-lora.png)\n", "\n", "#### Full Fine-tuning Loss Plots\n", "![Fine-tuning Loss Plots](images/loss-plot-full-finetuning.png)\n" ] }, { "cell_type": "markdown", "id": "6bcf75b4", "metadata": {}, "source": [ "## 📤 Step 5 (Optional): Export the LoRA Adapter to Hugging Face Format for Deployment\n", "\n", "After successfully training your reasoning model, you may want to export your LoRA adapter to the Hugging Face format for easier deployment and sharing. 
\n", "This step is **essential** if you plan to deploy your model using popular inference frameworks such as:\n", "* 🚀 **[NVIDIA NIM](https://developer.nvidia.com/nim?sortBy=developer_learning_library%2Fsort%2Ffeatured_in.nim%3Adesc%2Ctitle%3Aasc&hitsPerPage=12)**\n", "* ⚡ **[vLLM](https://docs.vllm.ai/en/latest/)**\n" ] }, { "cell_type": "code", "execution_count": null, "id": "636e279d", "metadata": {}, "outputs": [], "source": [ "from nemo.collections import llm\n", "import nemo_run as run\n", "\n", "def configure_checkpoint_conversion():\n", "    return run.Partial(\n", "        llm.export_ckpt,\n", "        path=OUTPUT_DIR, # Replace with the path to your trained checkpoint under OUTPUT_DIR\n", "        target=\"hf-peft\",\n", "        output_path=OUTPUT_DEPLOYMENT_DIR, # Where the exported Hugging Face adapter is written\n", "        overwrite=True, # <-- IMPORTANT!\n", "    )\n", "\n", "# Configure the export task\n", "export_ckpt = configure_checkpoint_conversion()\n", "# Define the local executor\n", "local_executor = run.LocalExecutor()\n", "\n", "# Run the export\n", "run.run(export_ckpt, executor=local_executor)" ] }, { "cell_type": "markdown", "id": "e58635c3-08a9-4cec-9fe0-edd0b94036a8", "metadata": {}, "source": [ "## 🎉 Tada! You Just Trained Your First Reasoning Model!\n", "Congratulations! You've successfully fine-tuned LLaMA 3.1-8B Instruct into a reasoning model using NeMo Framework!\n", "\n", "Your model is now ready to:\n", "\n", "* Solve problems through reasoning and analysis.\n", "* Selectively perform reasoning (i.e. 
`reasoning on` or `reasoning off`).\n", "\n", "### 🚀 Next Steps\n", "* 🧪 Evaluate your model on reasoning benchmarks (e.g., MMLU, GSM8K)\n", "* ☁️ Package the model for deployment or inference with Triton or vLLM\n", "* 📤 Optionally, upload it to Hugging Face or NGC to share with the world" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }