{ "cells": [ { "cell_type": "code", "execution_count": 2, "id": "d13f7470", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n" ] } ], "source": [ "import torch\n", "print(torch.cuda.is_available())" ] }, { "cell_type": "code", "execution_count": 3, "id": "d944a23b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: torch in /opt/jhub-venv/lib/python3.12/site-packages (2.6.0+cu124)\n", "Requirement already satisfied: transformers in ./myenv/lib/python3.12/site-packages (4.57.1)\n", "Requirement already satisfied: peft in ./myenv/lib/python3.12/site-packages (0.18.0)\n", "Requirement already satisfied: bitsandbytes in ./myenv/lib/python3.12/site-packages (0.48.2)\n", "Requirement already satisfied: trl in ./myenv/lib/python3.12/site-packages (0.25.1)\n", "Requirement already satisfied: datasets in ./myenv/lib/python3.12/site-packages (4.4.1)\n", "Requirement already satisfied: accelerate in ./myenv/lib/python3.12/site-packages (1.11.0)\n", "Requirement already satisfied: jsonlines in ./myenv/lib/python3.12/site-packages (4.0.0)\n", "Requirement already satisfied: filelock in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (3.19.1)\n", "Requirement already satisfied: typing-extensions>=4.10.0 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (4.15.0)\n", "Requirement already satisfied: networkx in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (3.5)\n", "Requirement already satisfied: jinja2 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (3.1.6)\n", "Requirement already satisfied: fsspec in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (2025.9.0)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /opt/jhub-venv/lib/python3.12/site-packages (from 
torch) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (12.4.127)\n", "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (9.1.0.70)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (12.4.5.8)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (11.2.1.3)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (10.3.5.147)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (11.6.1.9)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (12.3.1.170)\n", "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (0.6.2)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (2.21.5)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (12.4.127)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (12.4.127)\n", "Requirement already satisfied: triton==3.2.0 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (3.2.0)\n", "Requirement already satisfied: setuptools in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (70.2.0)\n", "Requirement already satisfied: sympy==1.13.1 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (1.13.1)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/jhub-venv/lib/python3.12/site-packages (from 
sympy==1.13.1->torch) (1.3.0)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.34.0 in ./myenv/lib/python3.12/site-packages (from transformers) (0.36.0)\n", "Requirement already satisfied: numpy>=1.17 in /opt/jhub-venv/lib/python3.12/site-packages (from transformers) (2.3.3)\n", "Requirement already satisfied: packaging>=20.0 in /opt/jhub-venv/lib/python3.12/site-packages (from transformers) (25.0)\n", "Requirement already satisfied: pyyaml>=5.1 in /opt/jhub-venv/lib/python3.12/site-packages (from transformers) (6.0.3)\n", "Requirement already satisfied: regex!=2019.12.17 in ./myenv/lib/python3.12/site-packages (from transformers) (2025.11.3)\n", "Requirement already satisfied: requests in /opt/jhub-venv/lib/python3.12/site-packages (from transformers) (2.32.5)\n", "Requirement already satisfied: tokenizers<=0.23.0,>=0.22.0 in ./myenv/lib/python3.12/site-packages (from transformers) (0.22.1)\n", "Requirement already satisfied: safetensors>=0.4.3 in ./myenv/lib/python3.12/site-packages (from transformers) (0.6.2)\n", "Requirement already satisfied: tqdm>=4.27 in ./myenv/lib/python3.12/site-packages (from transformers) (4.67.1)\n", "Requirement already satisfied: psutil in /opt/jhub-venv/lib/python3.12/site-packages (from peft) (7.1.3)\n", "Requirement already satisfied: pyarrow>=21.0.0 in ./myenv/lib/python3.12/site-packages (from datasets) (22.0.0)\n", "Requirement already satisfied: dill<0.4.1,>=0.3.0 in ./myenv/lib/python3.12/site-packages (from datasets) (0.4.0)\n", "Requirement already satisfied: pandas in ./myenv/lib/python3.12/site-packages (from datasets) (2.3.3)\n", "Requirement already satisfied: httpx<1.0.0 in /opt/jhub-venv/lib/python3.12/site-packages (from datasets) (0.28.1)\n", "Requirement already satisfied: xxhash in ./myenv/lib/python3.12/site-packages (from datasets) (3.6.0)\n", "Requirement already satisfied: multiprocess<0.70.19 in ./myenv/lib/python3.12/site-packages (from datasets) (0.70.18)\n", "Requirement already satisfied: 
attrs>=19.2.0 in /opt/jhub-venv/lib/python3.12/site-packages (from jsonlines) (25.4.0)\n", "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in ./myenv/lib/python3.12/site-packages (from fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (3.13.2)\n", "Requirement already satisfied: anyio in /opt/jhub-venv/lib/python3.12/site-packages (from httpx<1.0.0->datasets) (4.11.0)\n", "Requirement already satisfied: certifi in /opt/jhub-venv/lib/python3.12/site-packages (from httpx<1.0.0->datasets) (2025.10.5)\n", "Requirement already satisfied: httpcore==1.* in /opt/jhub-venv/lib/python3.12/site-packages (from httpx<1.0.0->datasets) (1.0.9)\n", "Requirement already satisfied: idna in /opt/jhub-venv/lib/python3.12/site-packages (from httpx<1.0.0->datasets) (3.11)\n", "Requirement already satisfied: h11>=0.16 in /opt/jhub-venv/lib/python3.12/site-packages (from httpcore==1.*->httpx<1.0.0->datasets) (0.16.0)\n", "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in ./myenv/lib/python3.12/site-packages (from huggingface-hub<1.0,>=0.34.0->transformers) (1.2.0)\n", "Requirement already satisfied: charset_normalizer<4,>=2 in /opt/jhub-venv/lib/python3.12/site-packages (from requests->transformers) (3.4.4)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/jhub-venv/lib/python3.12/site-packages (from requests->transformers) (2.5.0)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /opt/jhub-venv/lib/python3.12/site-packages (from jinja2->torch) (2.1.5)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /opt/jhub-venv/lib/python3.12/site-packages (from pandas->datasets) (2.9.0.post0)\n", "Requirement already satisfied: pytz>=2020.1 in ./myenv/lib/python3.12/site-packages (from pandas->datasets) (2025.2)\n", "Requirement already satisfied: tzdata>=2022.7 in /opt/jhub-venv/lib/python3.12/site-packages (from pandas->datasets) (2025.2)\n", "Requirement already satisfied: aiohappyeyeballs>=2.5.0 in ./myenv/lib/python3.12/site-packages (from 
aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (2.6.1)\n", "Requirement already satisfied: aiosignal>=1.4.0 in ./myenv/lib/python3.12/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.4.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in ./myenv/lib/python3.12/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.8.0)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in ./myenv/lib/python3.12/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (6.7.0)\n", "Requirement already satisfied: propcache>=0.2.0 in ./myenv/lib/python3.12/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (0.4.1)\n", "Requirement already satisfied: yarl<2.0,>=1.17.0 in ./myenv/lib/python3.12/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.22.0)\n", "Requirement already satisfied: six>=1.5 in /opt/jhub-venv/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.17.0)\n", "Requirement already satisfied: sniffio>=1.1 in /opt/jhub-venv/lib/python3.12/site-packages (from anyio->httpx<1.0.0->datasets) (1.3.1)\n", "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "%pip install torch transformers peft bitsandbytes trl datasets accelerate jsonlines" ] }, { "cell_type": "markdown", "id": "20479002", "metadata": {}, "source": [ "## 라이브러리 설정" ] }, { "cell_type": "code", "execution_count": 4, "id": "9eba4d06", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/codeit01team/myenv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
 See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import torch # needed below for torch.bfloat16\n", "\n", "# data loading\n", "from datasets import load_dataset\n", "from transformers import (\n", " AutoModelForCausalLM,\n", " AutoTokenizer,\n", " BitsAndBytesConfig,\n", ")\n", "from peft import (\n", " LoraConfig,\n", " get_peft_model,\n", " prepare_model_for_kbit_training,\n", " TaskType,\n", ")\n", "from trl import SFTTrainer, SFTConfig" ] }, { "cell_type": "code", "execution_count": 5, "id": "804cb86f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'/home/codeit01team'" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pwd" ] }, { "cell_type": "markdown", "id": "62466cf3", "metadata": {}, "source": [ "# 1. Basic QLoRA setup (HuggingFace PEFT + BitsAndBytes)" ] }, { "cell_type": "code", "execution_count": 6, "id": "0b56e9fe", "metadata": {}, "outputs": [], "source": [ "# 1. 4-bit quantization config (NF4, double quantization, bf16 compute dtype)\n", "bnb_config = BitsAndBytesConfig(\n", " load_in_4bit=True,\n", " bnb_4bit_quant_type=\"nf4\",\n", " bnb_4bit_use_double_quant=True,\n", " bnb_4bit_compute_dtype=torch.bfloat16\n", ")" ] }, { "cell_type": "code", "execution_count": 7, "id": "9a2f3aa9", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading checkpoint shards: 100%|██████████| 6/6 [02:13<00:00, 22.28s/it]\n" ] } ], "source": [ "# 2. 
Load the model\n", "model_name = \"beomi/Llama-3-Open-Ko-8B\"\n", "model = AutoModelForCausalLM.from_pretrained(\n", " model_name,\n", " quantization_config=bnb_config,\n", " device_map=\"auto\",\n", " trust_remote_code=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 8, "id": "9885213c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The history saving thread hit an unexpected error (OperationalError('database or disk is full')).History will not be written to the database.\n" ] } ], "source": [ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", "tokenizer.pad_token = tokenizer.eos_token\n", "tokenizer.padding_side = \"right\"" ] }, { "cell_type": "code", "execution_count": 9, "id": "0e8d4d55", "metadata": {}, "outputs": [], "source": [ "# 3. Prepare the model for k-bit (quantized) training\n", "model = prepare_model_for_kbit_training(model)" ] }, { "cell_type": "code", "execution_count": 10, "id": "d8a4198b", "metadata": {}, "outputs": [], "source": [ "# 4. LoRA config (core attention modules only - saves memory)\n", "lora_config = LoraConfig(\n", " r=16,\n", " lora_alpha=32,\n", " lora_dropout=0.05,\n", " bias=\"none\",\n", " task_type=TaskType.CAUSAL_LM,\n", " target_modules=[\n", " \"q_proj\",\n", " \"k_proj\", \n", " \"v_proj\",\n", " \"o_proj\",\n", " # gate_proj, up_proj, down_proj removed - performance difference is small\n", " ]\n", ")" ] }, { "cell_type": "code", "execution_count": 11, "id": "7eaa969c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "trainable params: 13,631,488 || all params: 8,043,892,736 || trainable%: 0.1695\n" ] } ], "source": [ "# 5. Apply LoRA\n", "model = get_peft_model(model, lora_config)\n", "model.print_trainable_parameters()" ] }, { "cell_type": "markdown", "id": "dc847247", "metadata": {}, "source": [ "# 2. 
Dataset preparation and training\n", "\n", "Why a `HuggingFace Dataset` must be used\n", "\n", "- SFTTrainer takes a HuggingFace Dataset as input\n", "- Handles batching, shuffling, and tokenization automatically\n", "- Memory efficient (lazy loading)" ] }, { "cell_type": "code", "execution_count": 12, "id": "17143379", "metadata": {}, "outputs": [], "source": [ "# Load the data (NOTE: not actually streaming - no streaming=True is passed)\n", "data_path = \"data/sft_train_llama.jsonl\"\n", "dataset = load_dataset(\n", " \"json\",\n", " data_files=data_path,\n", " split=\"train\", \n", ")" ] }, { "cell_type": "code", "execution_count": 13, "id": "b47e061d", "metadata": {}, "outputs": [], "source": [ "# Training config - memory optimized\n", "sft_config = SFTConfig(\n", " output_dir=\"./qlora_output\",\n", " num_train_epochs=2, # active: trains for 2 full epochs\n", " per_device_train_batch_size=1,\n", " gradient_accumulation_steps=16,\n", " learning_rate=2e-4,\n", " bf16=True,\n", " logging_steps=10,\n", " save_steps=500, # checkpoint every 500 steps instead of per epoch\n", " optim=\"paged_adamw_8bit\",\n", " gradient_checkpointing=True,\n", " max_length=512, \n", " dataset_text_field=\"text\",\n", ")" ] }, { "cell_type": "code", "execution_count": 14, "id": "5b82e6f5", "metadata": {}, "outputs": [], "source": [ "# Train with SFTTrainer\n", "trainer = SFTTrainer(\n", " model=model,\n", " args=sft_config,\n", " train_dataset=dataset,\n", " processing_class=tokenizer,\n", ")" ] }, { "cell_type": "code", "execution_count": 15, "id": "dd78a9cb", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 128001, 'bos_token_id': 128000, 'pad_token_id': 128001}.\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [1058/1058 6:18:56, Epoch 2/2]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
102.525100
201.933300
301.800700
401.775700
501.773700
601.719500
701.715900
801.681300
901.671000
1001.690100
1101.712000
1201.609300
1301.614200
1401.663300
1501.624000
1601.586000
1701.614200
1801.570300
1901.609900
2001.586800
2101.523200
2201.595500
2301.604200
2401.518400
2501.551400
2601.521200
2701.585300
2801.575400
2901.507000
3001.539600
3101.489900
3201.459300
3301.555300
3401.520400
3501.549200
3601.530700
3701.532300
3801.479400
3901.469400
4001.470800
4101.505100
4201.472500
4301.477300
4401.467300
4501.459700
4601.484500
4701.499100
4801.459900
4901.430800
5001.484700
5101.459500
5201.437000
5301.433800
5401.363500
5501.348800
5601.360600
5701.307000
5801.350000
5901.436000
6001.402600
6101.369600
6201.421000
6301.377700
6401.365100
6501.326400
6601.414200
6701.400100
6801.330200
6901.380400
7001.357300
7101.387900
7201.368100
7301.312700
7401.354500
7501.343500
7601.371200
7701.292800
7801.356000
7901.353400
8001.406300
8101.376100
8201.297200
8301.405000
8401.373500
8501.338300
8601.368300
8701.398800
8801.337500
8901.367700
9001.312600
9101.353600
9201.317400
9301.348200
9401.361800
9501.290600
9601.384400
9701.290200
9801.348800
9901.330100
10001.384700
10101.368200
10201.347500
10301.332400
10401.315800
10501.348300

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "TrainOutput(global_step=1058, training_loss=1.4723652360119306, metrics={'train_runtime': 22771.7007, 'train_samples_per_second': 0.743, 'train_steps_per_second': 0.046, 'total_flos': 2.071327636721664e+17, 'train_loss': 1.4723652360119306, 'entropy': 1.4031210680802664, 'num_tokens': 4591590.0, 'mean_token_accuracy': 0.6802353163560232, 'epoch': 2.0})" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Start training\n", "trainer.train()" ] }, { "cell_type": "code", "execution_count": 16, "id": "94871386", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "('./qlora_adapter/tokenizer_config.json',\n", " './qlora_adapter/special_tokens_map.json',\n", " './qlora_adapter/chat_template.jinja',\n", " './qlora_adapter/tokenizer.json')" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Save the model (LoRA adapter weights only, not the full base model)\n", "trainer.model.save_pretrained(\"./qlora_adapter\")\n", "tokenizer.save_pretrained(\"./qlora_adapter\")" ] }, { "cell_type": "code", "execution_count": null, "id": "1fbac3b6", "metadata": {}, "outputs": [], "source": [ "# Optional: subsample the dataset for a quick smoke-test run\n", "# dataset = dataset.shuffle(seed=42).select(range(1000))" ] } ], "metadata": { "kernelspec": { "display_name": "Python (myenv)", "language": "python", "name": "myenv" } }, "nbformat": 4, "nbformat_minor": 5 }