{ "cells": [ { "cell_type": "code", "execution_count": 2, "id": "d13f7470", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n" ] } ], "source": [ "import torch\n", "print(torch.cuda.is_available())" ] }, { "cell_type": "code", "execution_count": 3, "id": "d944a23b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: torch in /opt/jhub-venv/lib/python3.12/site-packages (2.6.0+cu124)\n", "Requirement already satisfied: transformers in ./myenv/lib/python3.12/site-packages (4.57.1)\n", "Requirement already satisfied: peft in ./myenv/lib/python3.12/site-packages (0.18.0)\n", "Requirement already satisfied: bitsandbytes in ./myenv/lib/python3.12/site-packages (0.48.2)\n", "Requirement already satisfied: trl in ./myenv/lib/python3.12/site-packages (0.25.1)\n", "Requirement already satisfied: datasets in ./myenv/lib/python3.12/site-packages (4.4.1)\n", "Requirement already satisfied: accelerate in ./myenv/lib/python3.12/site-packages (1.11.0)\n", "Requirement already satisfied: jsonlines in ./myenv/lib/python3.12/site-packages (4.0.0)\n", "Requirement already satisfied: filelock in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (3.19.1)\n", "Requirement already satisfied: typing-extensions>=4.10.0 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (4.15.0)\n", "Requirement already satisfied: networkx in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (3.5)\n", "Requirement already satisfied: jinja2 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (3.1.6)\n", "Requirement already satisfied: fsspec in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (2025.9.0)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /opt/jhub-venv/lib/python3.12/site-packages (from 
torch) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (12.4.127)\n", "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (9.1.0.70)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (12.4.5.8)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (11.2.1.3)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (10.3.5.147)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (11.6.1.9)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (12.3.1.170)\n", "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (0.6.2)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (2.21.5)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (12.4.127)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (12.4.127)\n", "Requirement already satisfied: triton==3.2.0 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (3.2.0)\n", "Requirement already satisfied: setuptools in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (70.2.0)\n", "Requirement already satisfied: sympy==1.13.1 in /opt/jhub-venv/lib/python3.12/site-packages (from torch) (1.13.1)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/jhub-venv/lib/python3.12/site-packages (from 
sympy==1.13.1->torch) (1.3.0)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.34.0 in ./myenv/lib/python3.12/site-packages (from transformers) (0.36.0)\n", "Requirement already satisfied: numpy>=1.17 in /opt/jhub-venv/lib/python3.12/site-packages (from transformers) (2.3.3)\n", "Requirement already satisfied: packaging>=20.0 in /opt/jhub-venv/lib/python3.12/site-packages (from transformers) (25.0)\n", "Requirement already satisfied: pyyaml>=5.1 in /opt/jhub-venv/lib/python3.12/site-packages (from transformers) (6.0.3)\n", "Requirement already satisfied: regex!=2019.12.17 in ./myenv/lib/python3.12/site-packages (from transformers) (2025.11.3)\n", "Requirement already satisfied: requests in /opt/jhub-venv/lib/python3.12/site-packages (from transformers) (2.32.5)\n", "Requirement already satisfied: tokenizers<=0.23.0,>=0.22.0 in ./myenv/lib/python3.12/site-packages (from transformers) (0.22.1)\n", "Requirement already satisfied: safetensors>=0.4.3 in ./myenv/lib/python3.12/site-packages (from transformers) (0.6.2)\n", "Requirement already satisfied: tqdm>=4.27 in ./myenv/lib/python3.12/site-packages (from transformers) (4.67.1)\n", "Requirement already satisfied: psutil in /opt/jhub-venv/lib/python3.12/site-packages (from peft) (7.1.3)\n", "Requirement already satisfied: pyarrow>=21.0.0 in ./myenv/lib/python3.12/site-packages (from datasets) (22.0.0)\n", "Requirement already satisfied: dill<0.4.1,>=0.3.0 in ./myenv/lib/python3.12/site-packages (from datasets) (0.4.0)\n", "Requirement already satisfied: pandas in ./myenv/lib/python3.12/site-packages (from datasets) (2.3.3)\n", "Requirement already satisfied: httpx<1.0.0 in /opt/jhub-venv/lib/python3.12/site-packages (from datasets) (0.28.1)\n", "Requirement already satisfied: xxhash in ./myenv/lib/python3.12/site-packages (from datasets) (3.6.0)\n", "Requirement already satisfied: multiprocess<0.70.19 in ./myenv/lib/python3.12/site-packages (from datasets) (0.70.18)\n", "Requirement already satisfied: 
attrs>=19.2.0 in /opt/jhub-venv/lib/python3.12/site-packages (from jsonlines) (25.4.0)\n", "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in ./myenv/lib/python3.12/site-packages (from fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (3.13.2)\n", "Requirement already satisfied: anyio in /opt/jhub-venv/lib/python3.12/site-packages (from httpx<1.0.0->datasets) (4.11.0)\n", "Requirement already satisfied: certifi in /opt/jhub-venv/lib/python3.12/site-packages (from httpx<1.0.0->datasets) (2025.10.5)\n", "Requirement already satisfied: httpcore==1.* in /opt/jhub-venv/lib/python3.12/site-packages (from httpx<1.0.0->datasets) (1.0.9)\n", "Requirement already satisfied: idna in /opt/jhub-venv/lib/python3.12/site-packages (from httpx<1.0.0->datasets) (3.11)\n", "Requirement already satisfied: h11>=0.16 in /opt/jhub-venv/lib/python3.12/site-packages (from httpcore==1.*->httpx<1.0.0->datasets) (0.16.0)\n", "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in ./myenv/lib/python3.12/site-packages (from huggingface-hub<1.0,>=0.34.0->transformers) (1.2.0)\n", "Requirement already satisfied: charset_normalizer<4,>=2 in /opt/jhub-venv/lib/python3.12/site-packages (from requests->transformers) (3.4.4)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/jhub-venv/lib/python3.12/site-packages (from requests->transformers) (2.5.0)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /opt/jhub-venv/lib/python3.12/site-packages (from jinja2->torch) (2.1.5)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /opt/jhub-venv/lib/python3.12/site-packages (from pandas->datasets) (2.9.0.post0)\n", "Requirement already satisfied: pytz>=2020.1 in ./myenv/lib/python3.12/site-packages (from pandas->datasets) (2025.2)\n", "Requirement already satisfied: tzdata>=2022.7 in /opt/jhub-venv/lib/python3.12/site-packages (from pandas->datasets) (2025.2)\n", "Requirement already satisfied: aiohappyeyeballs>=2.5.0 in ./myenv/lib/python3.12/site-packages (from 
aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (2.6.1)\n", "Requirement already satisfied: aiosignal>=1.4.0 in ./myenv/lib/python3.12/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.4.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in ./myenv/lib/python3.12/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.8.0)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in ./myenv/lib/python3.12/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (6.7.0)\n", "Requirement already satisfied: propcache>=0.2.0 in ./myenv/lib/python3.12/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (0.4.1)\n", "Requirement already satisfied: yarl<2.0,>=1.17.0 in ./myenv/lib/python3.12/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.22.0)\n", "Requirement already satisfied: six>=1.5 in /opt/jhub-venv/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.17.0)\n", "Requirement already satisfied: sniffio>=1.1 in /opt/jhub-venv/lib/python3.12/site-packages (from anyio->httpx<1.0.0->datasets) (1.3.1)\n", "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "%pip install torch transformers peft bitsandbytes trl datasets accelerate jsonlines" ] }, { "cell_type": "markdown", "id": "20479002", "metadata": {}, "source": [ "## 라이브러리 설정" ] }, { "cell_type": "code", "execution_count": 4, "id": "9eba4d06", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/codeit01team/myenv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import torch # ← 추가!\n", "\n", "# 데이터 로드\n", "from datasets import load_dataset\n", "from transformers import (\n", " AutoModelForCausalLM,\n", " AutoTokenizer,\n", " BitsAndBytesConfig,\n", ")\n", "from peft import (\n", " LoraConfig,\n", " get_peft_model,\n", " prepare_model_for_kbit_training,\n", " TaskType,\n", ")\n", "from trl import SFTTrainer, SFTConfig" ] }, { "cell_type": "code", "execution_count": 5, "id": "804cb86f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'/home/codeit01team'" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pwd" ] }, { "cell_type": "markdown", "id": "62466cf3", "metadata": {}, "source": [ "# 1. 기본 QLoRA 설정 (HuggingFace PEFT + BitsAndBytes)" ] }, { "cell_type": "code", "execution_count": 6, "id": "0b56e9fe", "metadata": {}, "outputs": [], "source": [ "# 1. 4-bit 양자화 설정\n", "bnb_config = BitsAndBytesConfig(\n", " load_in_4bit=True,\n", " bnb_4bit_quant_type=\"nf4\",\n", " bnb_4bit_use_double_quant=True,\n", " bnb_4bit_compute_dtype=torch.bfloat16\n", ")" ] }, { "cell_type": "code", "execution_count": 7, "id": "9a2f3aa9", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading checkpoint shards: 100%|██████████| 6/6 [02:13<00:00, 22.28s/it]\n" ] } ], "source": [ "# 2. 
모델 로드\n", "model_name = \"beomi/Llama-3-Open-Ko-8B\"\n", "model = AutoModelForCausalLM.from_pretrained(\n", " model_name,\n", " quantization_config=bnb_config,\n", " device_map=\"auto\",\n", " trust_remote_code=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 8, "id": "9885213c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The history saving thread hit an unexpected error (OperationalError('database or disk is full')).History will not be written to the database.\n" ] } ], "source": [ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", "tokenizer.pad_token = tokenizer.eos_token\n", "tokenizer.padding_side = \"right\"" ] }, { "cell_type": "code", "execution_count": 9, "id": "0e8d4d55", "metadata": {}, "outputs": [], "source": [ "# 3. kbit 학습을 위한 모델 준비\n", "model = prepare_model_for_kbit_training(model)" ] }, { "cell_type": "code", "execution_count": 10, "id": "d8a4198b", "metadata": {}, "outputs": [], "source": [ "# 4. LoRA 설정 (핵심 모듈만 - 메모리 절약)\n", "lora_config = LoraConfig(\n", " r=16,\n", " lora_alpha=32,\n", " lora_dropout=0.05,\n", " bias=\"none\",\n", " task_type=TaskType.CAUSAL_LM,\n", " target_modules=[\n", " \"q_proj\",\n", " \"k_proj\", \n", " \"v_proj\",\n", " \"o_proj\",\n", " # gate_proj, up_proj, down_proj 제거 - 성능 차이 크지 않음\n", " ]\n", ")" ] }, { "cell_type": "code", "execution_count": 11, "id": "7eaa969c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "trainable params: 13,631,488 || all params: 8,043,892,736 || trainable%: 0.1695\n" ] } ], "source": [ "# 5. LoRA 적용\n", "model = get_peft_model(model, lora_config)\n", "model.print_trainable_parameters()" ] }, { "cell_type": "markdown", "id": "dc847247", "metadata": {}, "source": [ "# 2. 
데이터셋 준비 및 학습\n", "\n", "`HuggingFace Dataset`을 사용해야 하는 이유\n", "\n", "- SFTTrainer가 HuggingFace Dataset을 입력으로 받음\n", "- 자동으로 batching, shuffling, tokenization 처리\n", "- 메모리 효율적 (lazy loading)" ] }, { "cell_type": "code", "execution_count": 12, "id": "17143379", "metadata": {}, "outputs": [], "source": [ "# 데이터 로드 (JSONL → HuggingFace Dataset)\n", "data_path = \"data/sft_train_llama.jsonl\"\n", "dataset = load_dataset(\n", " \"json\",\n", " data_files=data_path,\n", " split=\"train\", \n", ")" ] }, { "cell_type": "code", "execution_count": 13, "id": "b47e061d", "metadata": {}, "outputs": [], "source": [ "# 학습 설정 - 메모리 최적화\n", "sft_config = SFTConfig(\n", " output_dir=\"./qlora_output\",\n", " num_train_epochs=2, # 총 2 epoch 학습\n", " per_device_train_batch_size=1,\n", " gradient_accumulation_steps=16,\n", " learning_rate=2e-4,\n", " bf16=True,\n", " logging_steps=10,\n", " save_steps=500, # ← epoch 대신 step 기준 저장\n", " optim=\"paged_adamw_8bit\",\n", " gradient_checkpointing=True,\n", " max_length=512, \n", " dataset_text_field=\"text\",\n", ")" ] }, { "cell_type": "code", "execution_count": 14, "id": "5b82e6f5", "metadata": {}, "outputs": [], "source": [ "# SFTTrainer로 학습\n", "trainer = SFTTrainer(\n", " model=model,\n", " args=sft_config,\n", " train_dataset=dataset,\n", " processing_class=tokenizer,\n", ")" ] }, { "cell_type": "code", "execution_count": 15, "id": "dd78a9cb", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 128001, 'bos_token_id': 128000, 'pad_token_id': 128001}.\n" ] }, { "data": { "text/html": [ "\n", "
| Step | \n", "Training Loss | \n", "
|---|---|
| 10 | \n", "2.525100 | \n", "
| 20 | \n", "1.933300 | \n", "
| 30 | \n", "1.800700 | \n", "
| 40 | \n", "1.775700 | \n", "
| 50 | \n", "1.773700 | \n", "
| 60 | \n", "1.719500 | \n", "
| 70 | \n", "1.715900 | \n", "
| 80 | \n", "1.681300 | \n", "
| 90 | \n", "1.671000 | \n", "
| 100 | \n", "1.690100 | \n", "
| 110 | \n", "1.712000 | \n", "
| 120 | \n", "1.609300 | \n", "
| 130 | \n", "1.614200 | \n", "
| 140 | \n", "1.663300 | \n", "
| 150 | \n", "1.624000 | \n", "
| 160 | \n", "1.586000 | \n", "
| 170 | \n", "1.614200 | \n", "
| 180 | \n", "1.570300 | \n", "
| 190 | \n", "1.609900 | \n", "
| 200 | \n", "1.586800 | \n", "
| 210 | \n", "1.523200 | \n", "
| 220 | \n", "1.595500 | \n", "
| 230 | \n", "1.604200 | \n", "
| 240 | \n", "1.518400 | \n", "
| 250 | \n", "1.551400 | \n", "
| 260 | \n", "1.521200 | \n", "
| 270 | \n", "1.585300 | \n", "
| 280 | \n", "1.575400 | \n", "
| 290 | \n", "1.507000 | \n", "
| 300 | \n", "1.539600 | \n", "
| 310 | \n", "1.489900 | \n", "
| 320 | \n", "1.459300 | \n", "
| 330 | \n", "1.555300 | \n", "
| 340 | \n", "1.520400 | \n", "
| 350 | \n", "1.549200 | \n", "
| 360 | \n", "1.530700 | \n", "
| 370 | \n", "1.532300 | \n", "
| 380 | \n", "1.479400 | \n", "
| 390 | \n", "1.469400 | \n", "
| 400 | \n", "1.470800 | \n", "
| 410 | \n", "1.505100 | \n", "
| 420 | \n", "1.472500 | \n", "
| 430 | \n", "1.477300 | \n", "
| 440 | \n", "1.467300 | \n", "
| 450 | \n", "1.459700 | \n", "
| 460 | \n", "1.484500 | \n", "
| 470 | \n", "1.499100 | \n", "
| 480 | \n", "1.459900 | \n", "
| 490 | \n", "1.430800 | \n", "
| 500 | \n", "1.484700 | \n", "
| 510 | \n", "1.459500 | \n", "
| 520 | \n", "1.437000 | \n", "
| 530 | \n", "1.433800 | \n", "
| 540 | \n", "1.363500 | \n", "
| 550 | \n", "1.348800 | \n", "
| 560 | \n", "1.360600 | \n", "
| 570 | \n", "1.307000 | \n", "
| 580 | \n", "1.350000 | \n", "
| 590 | \n", "1.436000 | \n", "
| 600 | \n", "1.402600 | \n", "
| 610 | \n", "1.369600 | \n", "
| 620 | \n", "1.421000 | \n", "
| 630 | \n", "1.377700 | \n", "
| 640 | \n", "1.365100 | \n", "
| 650 | \n", "1.326400 | \n", "
| 660 | \n", "1.414200 | \n", "
| 670 | \n", "1.400100 | \n", "
| 680 | \n", "1.330200 | \n", "
| 690 | \n", "1.380400 | \n", "
| 700 | \n", "1.357300 | \n", "
| 710 | \n", "1.387900 | \n", "
| 720 | \n", "1.368100 | \n", "
| 730 | \n", "1.312700 | \n", "
| 740 | \n", "1.354500 | \n", "
| 750 | \n", "1.343500 | \n", "
| 760 | \n", "1.371200 | \n", "
| 770 | \n", "1.292800 | \n", "
| 780 | \n", "1.356000 | \n", "
| 790 | \n", "1.353400 | \n", "
| 800 | \n", "1.406300 | \n", "
| 810 | \n", "1.376100 | \n", "
| 820 | \n", "1.297200 | \n", "
| 830 | \n", "1.405000 | \n", "
| 840 | \n", "1.373500 | \n", "
| 850 | \n", "1.338300 | \n", "
| 860 | \n", "1.368300 | \n", "
| 870 | \n", "1.398800 | \n", "
| 880 | \n", "1.337500 | \n", "
| 890 | \n", "1.367700 | \n", "
| 900 | \n", "1.312600 | \n", "
| 910 | \n", "1.353600 | \n", "
| 920 | \n", "1.317400 | \n", "
| 930 | \n", "1.348200 | \n", "
| 940 | \n", "1.361800 | \n", "
| 950 | \n", "1.290600 | \n", "
| 960 | \n", "1.384400 | \n", "
| 970 | \n", "1.290200 | \n", "
| 980 | \n", "1.348800 | \n", "
| 990 | \n", "1.330100 | \n", "
| 1000 | \n", "1.384700 | \n", "
| 1010 | \n", "1.368200 | \n", "
| 1020 | \n", "1.347500 | \n", "
| 1030 | \n", "1.332400 | \n", "
| 1040 | \n", "1.315800 | \n", "
| 1050 | \n", "1.348300 | \n", "
"
],
"text/plain": [
"