{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "e6d20008-a91c-4618-baa0-5991e031f1bd", "metadata": { "execution": { "iopub.execute_input": "2025-05-13T21:48:57.985184Z", "iopub.status.busy": "2025-05-13T21:48:57.984795Z", "iopub.status.idle": "2025-05-13T21:51:48.369715Z", "shell.execute_reply": "2025-05-13T21:51:48.368907Z", "shell.execute_reply.started": "2025-05-13T21:48:57.985144Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/root/notebooks/MT_TQ/Libraries/timedlibs/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "from transformers import AutoProcessor, Gemma3ForConditionalGeneration, Trainer, TrainingArguments, DataCollatorForSeq2Seq\n", "import torch\n", "from peft import LoraConfig, get_peft_model\n", "\n", "import os\n", "from tqdm import tqdm\n", "import json\n", "\n", "import random\n", "from datasets import load_dataset\n", "from datasets import Dataset, DatasetDict" ] }, { "cell_type": "code", "execution_count": 3, "id": "67f95fc8-a9d8-48cf-a551-7d30781cdb55", "metadata": { "execution": { "iopub.execute_input": "2025-05-13T21:53:58.075473Z", "iopub.status.busy": "2025-05-13T21:53:58.074767Z", "iopub.status.idle": "2025-05-13T21:53:58.767860Z", "shell.execute_reply": "2025-05-13T21:53:58.767319Z", "shell.execute_reply.started": "2025-05-13T21:53:58.075446Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 8/8 [00:00<00:00, 22.76it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['messages'],\n", " num_rows: 309\n", " })\n", " test: Dataset({\n", " features: ['messages'],\n", " num_rows: 343\n", " })\n", "})\n" ] } ], "source": [ "data_path = (\n", " \"/root/notebooks/MT_TQ/Caches/May2025/tquality.annotated.data/parsed/pldl/\"\n", ")\n", "\n", "json_files = [\n", " os.path.join(root, file)\n", " for root, _, files in os.walk(data_path)\n", " for file in files\n", " if file.endswith(\".json\")\n", "]\n", "\n", "training_samples = []\n", "testing_samples = []\n", "\n", "for json_file in tqdm(json_files):\n", " with open(json_file, \"r\") as file:\n", " data = json.load(file)\n", " sampled_items = data[\"data\"]\n", " if \"test\" in json_file:\n", " testing_samples.extend(sampled_items)\n", " if \"train\" in json_file:\n", " training_samples.extend(sampled_items)\n", "\n", "training_datapoints = []\n", "testing_datapoints = []\n", "\n", "for idx, sample in enumerate(training_samples):\n", " datapoint = {\"input\": {}}\n", " datapoint[\"input\"][\"src_text\"] = sample[\"src_text\"]\n", " datapoint[\"input\"][\"tgt_text\"] = sample[\"main_tgt_text\"]\n", " datapoint[\"input\"][\"src_prev\"] = sample[\"tt_src_prev\"]\n", " datapoint[\"input\"][\"src_next\"] = sample[\"tt_src_next\"]\n", " datapoint[\"input\"][\"tgt_prev\"] = sample[\"tt_tgt_prev\"]\n", " datapoint[\"input\"][\"tgt_next\"] = sample[\"tt_tgt_next\"]\n", " datapoint[\"input\"][\"src_lang\"] = sample[\"src_lang\"]\n", " datapoint[\"input\"][\"tgt_lang\"] = sample[\"tgt_lang\"]\n", " datapoint[\"input\"][\"start_frame\"] = sample[\"start_frame\"]\n", " datapoint[\"input\"][\"end_frame\"] = sample[\"end_frame\"]\n", " datapoint[\"input\"][\"title_id\"] = sample[\"title_id\"]\n", " datapoint[\"input\"][\"alt_tgt_text\"]= sample[\"alt_tgt_text\"]\n", " datapoint[\"input\"][\"id\"] = idx\n", " datapoint[\"evaluation\"] = sample[\"labelers\"][0][\"annotation\"]\n", " training_datapoints.append(datapoint)\n", "\n", "for idx, sample in enumerate(testing_samples):\n", " datapoint = {\"input\": {}}\n", " datapoint[\"input\"][\"src_text\"] = sample[\"src_text\"]\n", " datapoint[\"input\"][\"tgt_text\"] = sample[\"main_tgt_text\"]\n", " datapoint[\"input\"][\"src_prev\"] = sample[\"tt_src_prev\"]\n", " datapoint[\"input\"][\"src_next\"] = sample[\"tt_src_next\"]\n", " datapoint[\"input\"][\"tgt_prev\"] = sample[\"tt_tgt_prev\"]\n", " datapoint[\"input\"][\"tgt_next\"] = sample[\"tt_tgt_next\"]\n", " datapoint[\"input\"][\"src_lang\"] = sample[\"src_lang\"]\n", " datapoint[\"input\"][\"tgt_lang\"] = sample[\"tgt_lang\"]\n", " datapoint[\"input\"][\"start_frame\"] = sample[\"start_frame\"]\n", " datapoint[\"input\"][\"end_frame\"] = sample[\"end_frame\"]\n", " datapoint[\"input\"][\"title_id\"] = sample[\"title_id\"]\n", " datapoint[\"input\"][\"alt_tgt_text\"]= sample[\"alt_tgt_text\"]\n", " datapoint[\"input\"][\"id\"] = idx\n", " datapoint[\"evaluation\"] = sample[\"labelers\"][0][\"annotation\"]\n", " testing_datapoints.append(datapoint)\n", "\n", "system_message = \"You are a helpful assistant who is an expert in estimating quality of translations.\"\n", "\n", "output_template = '''\n", "{\n", " \"Accuracy Issues\": [\n", " {\n", " \"Error Span\": \"\",\n", " \"Error Explanation\": \"\",\n", " \"Error Quality Category\": \"\",\n", " \"Error Quality Tags\": [],\n", " \"Error Severity\": \"\"\n", " }\n", " ],\n", " \"Accuracy Score\": \"\",\n", " \"Readability Issues\": [\n", " {\n", " \"Error Span\": \"\",\n", " \"Error Explanation\": \"\",\n", " \"Error Quality Category\": \"\",\n", " \"Error Quality Tags\": [],\n", " \"Error Severity\": \"\"\n", " }\n", " ],\n", " \"Readability Score\": \"\"\n", "}'''\n", "\n", "def create_conversation(input_sample, output_sample):\n", " return {\n", " \"messages\": [\n", " # {\"role\": \"system\", \"content\": system_message},\n", " {\"role\": \"user\", \"content\": input_sample},\n", " {\"role\": \"assistant\", \"content\": output_sample}\n", " ]\n", " }\n", "\n", "def create_dataset(datapoints, template_string):\n", " dataset = []\n", " meta = []\n", " for datapoint in datapoints:\n", " src_text = datapoint['input']['src_text']\n", " tgt_text = datapoint['input']['tgt_text']\n", " src_prev = datapoint['input']['src_prev']\n", " src_next = datapoint['input']['src_next'] \n", " tgt_prev = datapoint['input']['tgt_prev']\n", " tgt_next = datapoint['input']['tgt_next']\n", " src_lang = datapoint['input']['src_lang']\n", " tgt_lang = datapoint['input']['tgt_lang']\n", " \n", " start_frame = datapoint['input']['start_frame']\n", " end_frame = datapoint['input']['end_frame']\n", " title_id = datapoint['input']['title_id']\n", " output = datapoint['evaluation']\n", " idx = datapoint['input']['id']\n", " if len(output['Accuracy Issues']) != 0 or len(output['Readability Issues']) != 0:\n", " item = template_string.format(src_text=src_text, tgt_text=tgt_text, \n", " src_prev=src_prev, src_next=src_next, \n", " tgt_prev=tgt_prev, tgt_next=tgt_next, \n", " src_lang=src_lang, tgt_lang=tgt_lang,\n", " template=output_template)\n", " \n", " dataset.append(create_conversation(item, json.dumps(output)))\n", " meta.append({\"id\": idx, \"start_frame\": start_frame, \"end_frame\": end_frame, \"title_id\": title_id})\n", " \n", " return dataset, meta\n", " \n", "def dataset_prep(datapoints):\n", " with open(\"prompts.txt\") as file:\n", " template_string = file.read()\n", " dataset, meta = create_dataset(datapoints, template_string)\n", " return dataset, meta\n", "\n", "train_dataset, train_meta = dataset_prep(training_datapoints)\n", "test_dataset, test_meta = dataset_prep(testing_datapoints)\n", "\n", "dataset = {\"train\": train_dataset, \"test\": test_dataset}\n", "\n", "def convert_to_hf_dataset(dataset):\n", " train_dataset = Dataset.from_list(dataset['train'])\n", " test_dataset = Dataset.from_list(dataset['test'])\n", " \n", " hf_dataset = DatasetDict({\n", " 'train': train_dataset,\n", " 'test': test_dataset\n", " })\n", " \n", " return hf_dataset\n", "\n", "hf_dataset = convert_to_hf_dataset(dataset)\n", "print(hf_dataset)" ] }, { "cell_type": "code", "execution_count": 4, "id": "8b52f143-1077-4da6-ac92-b1dce5cdc17c", "metadata": { "execution": { "iopub.execute_input": "2025-05-13T21:54:12.568533Z", "iopub.status.busy": "2025-05-13T21:54:12.568078Z", "iopub.status.idle": "2025-05-13T21:54:49.724121Z", "shell.execute_reply": "2025-05-13T21:54:49.723481Z", "shell.execute_reply.started": "2025-05-13T21:54:12.568507Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading checkpoint shards: 100%|██████████| 12/12 [00:18<00:00, 1.58s/it]\n" ] } ], "source": [ "import torch\n", "from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, BitsAndBytesConfig\n", "from transformers import AutoProcessor, Gemma3ForConditionalGeneration\n", "device = torch.device(\"cuda:0\")\n", "\n", "# Hugging Face model id\n", "model_id = \"google/gemma-3-27b-it\" # or `google/gemma-3-4b-pt`, `google/gemma-3-12b-pt`, `google/gemma-3-27b-pt`\n", "\n", "# Select model class based on id\n", "if model_id == \"google/gemma-3-27b-it\":\n", " model_class = Gemma3ForConditionalGeneration\n", "else:\n", " model_class = AutoModelForImageTextToText\n", "\n", "torch_dtype = torch.bfloat16\n", "\n", "model_kwargs = dict(\n", " attn_implementation=\"eager\",\n", " torch_dtype=torch_dtype,\n", " device_map=\"auto\",\n", ")\n", "\n", "model = model_class.from_pretrained(model_id, **model_kwargs)\n", "tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-3-27b-it\") # Load the Instruction Tokenizer to use the official Gemma template" ] }, { "cell_type": "code", "execution_count": 5, "id": "8443dfd8-6193-480c-9937-f6e0c43a9f56", "metadata": { "execution": { "iopub.execute_input": "2025-05-13T21:55:12.713958Z", "iopub.status.busy": "2025-05-13T21:55:12.713495Z", "iopub.status.idle": "2025-05-13T21:55:12.717707Z", "shell.execute_reply": "2025-05-13T21:55:12.717199Z", "shell.execute_reply.started": "2025-05-13T21:55:12.713930Z" } }, "outputs": [], "source": [ "from peft import LoraConfig\n", "\n", "peft_config = LoraConfig(\n", " lora_alpha=128,\n", " lora_dropout=0.05,\n", " r=16,\n", " bias=\"none\",\n", " target_modules=\"all-linear\",\n", " task_type=\"CAUSAL_LM\",\n", " modules_to_save=[\"lm_head\", \"embed_tokens\"] # make sure to save the lm_head and embed_tokens as you train the special tokens\n", ")" ] }, { "cell_type": "code", "execution_count": 6, "id": "8f2b8371-ba1b-44ff-9462-d0c90335f82a", "metadata": { "execution": { "iopub.execute_input": "2025-05-13T21:55:22.076515Z", "iopub.status.busy": "2025-05-13T21:55:22.076029Z", "iopub.status.idle": "2025-05-13T21:55:22.783524Z", "shell.execute_reply": "2025-05-13T21:55:22.782937Z", "shell.execute_reply.started": "2025-05-13T21:55:22.076489Z" } }, "outputs": [], "source": [ "from trl import SFTConfig\n", "\n", "args = SFTConfig(\n", " output_dir=\"may13-gemma-27b-tq_sft_finetuned-model\",\n", " max_seq_length=2048,\n", " packing=True,\n", " num_train_epochs=1,\n", " per_device_train_batch_size=1,\n", " gradient_accumulation_steps=4,\n", " gradient_checkpointing=True,\n", " optim=\"adamw_torch_fused\",\n", " logging_steps=1,\n", " save_strategy=\"epoch\",\n", " learning_rate=1e-4,\n", " fp16=True if torch_dtype == torch.float16 else False,\n", " bf16=True if torch_dtype == torch.bfloat16 else False,\n", " max_grad_norm=0.3,\n", " warmup_ratio=0.03,\n", " lr_scheduler_type=\"constant\",\n", " push_to_hub=True,\n", " report_to=\"tensorboard\",\n", " dataset_kwargs={\n", " \"add_special_tokens\": False,\n", " \"append_concat_token\": True,\n", " },\n", " no_cuda=False,\n", ")" ] }, { "cell_type": "code", "execution_count": 7, "id": "2be55b87-70c9-4973-b0db-33154c272e47", "metadata": { "execution": { "iopub.execute_input": "2025-05-13T21:55:25.765385Z", "iopub.status.busy": "2025-05-13T21:55:25.764949Z", "iopub.status.idle": "2025-05-13T21:55:36.592163Z", "shell.execute_reply": "2025-05-13T21:55:36.591614Z", "shell.execute_reply.started": "2025-05-13T21:55:25.765360Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Converting train dataset to ChatML: 100%|██████████| 309/309 [00:00<00:00, 9533.70 examples/s]\n", "Applying chat template to train dataset: 100%|██████████| 309/309 [00:00<00:00, 4443.06 examples/s]\n", "Tokenizing train dataset: 100%|██████████| 309/309 [00:01<00:00, 226.22 examples/s]\n", "Packing train dataset: 100%|██████████| 309/309 [00:00<00:00, 102364.74 examples/s]\n", "No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.\n" ] } ], "source": [ "from trl import SFTTrainer\n", "\n", "# Create Trainer object\n", "trainer = SFTTrainer(\n", " model=model,\n", " args=args,\n", " train_dataset=hf_dataset[\"train\"],\n", " peft_config=peft_config,\n", " processing_class=tokenizer\n", ")" ] }, { "cell_type": "code", "execution_count": 8, "id": "d8d82767-27ed-48ed-ad22-3f3cf2dff15e", "metadata": { "execution": { "iopub.execute_input": "2025-05-13T22:00:25.107226Z", "iopub.status.busy": "2025-05-13T22:00:25.106569Z", "iopub.status.idle": "2025-05-13T22:27:35.945604Z", "shell.execute_reply": "2025-05-13T22:27:35.944775Z", "shell.execute_reply.started": "2025-05-13T22:00:25.107196Z" }, "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [80/80 22:44, Epoch 0/1]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
110.801900
28.381400
36.970200
45.784300
54.970800
64.389700
74.325000
83.557000
93.357700
103.092500
113.170300
122.648500
133.067800
142.377100
152.847700
162.628800
172.630800
182.820900
192.596700
202.675300
212.846300
222.706700
232.645100
242.214600
252.790700
262.640700
272.908900
282.690400
292.807200
302.713600
312.563200
322.412700
332.627700
342.431800
352.240600
362.650300
372.014900
382.463100
392.283300
402.450500
412.570400
422.550500
432.530600
442.551400
452.383000
462.550500
472.575900
482.494300
492.387200
502.318800
512.365200
522.190100
532.419100
542.290900
552.152500
562.398700
572.982500
582.380200
592.357500
602.386300
612.741300
622.850300
632.682100
642.972100
652.237800
662.518300
672.520700
682.122700
692.210200
702.414000
712.348200
722.470800
732.417400
742.562900
752.286800
762.671400
772.176200
782.284200
792.354700
802.363400

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "trainer.train()\n", "trainer.save_model()" ] }, { "cell_type": "code", "execution_count": 10, "id": "2398696f-eeb8-45d1-8dee-ed88a7ac140b", "metadata": { "execution": { "iopub.execute_input": "2025-05-13T22:34:47.172016Z", "iopub.status.busy": "2025-05-13T22:34:47.171574Z", "iopub.status.idle": "2025-05-13T22:39:06.055171Z", "shell.execute_reply": "2025-05-13T22:39:06.054429Z", "shell.execute_reply.started": "2025-05-13T22:34:47.171989Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.\n" ] }, { "data": { "text/plain": [ "('/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full/tokenizer_config.json',\n", " '/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full/special_tokens_map.json',\n", " '/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full/tokenizer.model',\n", " '/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full/added_tokens.json',\n", " '/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full/tokenizer.json')" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lora_model = trainer.model\n", "merged_model = lora_model.merge_and_unload()\n", "# Save the model with fused weights\n", "merged_model.save_pretrained('/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full')\n", "trainer.tokenizer.save_pretrained('/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full')" ] }, { "cell_type": "code", "execution_count": 1, "id": "8b811a84-0cdb-4b40-bb96-d6e6f27d41d3", "metadata": { "execution": { "iopub.execute_input": "2025-05-08T21:17:00.794785Z", "iopub.status.busy": "2025-05-08T21:17:00.794339Z", "iopub.status.idle": "2025-05-08T21:17:18.309148Z", "shell.execute_reply": "2025-05-08T21:17:18.308319Z", "shell.execute_reply.started": "2025-05-08T21:17:00.794761Z" } }, "outputs": [ { "ename": "NameError", "evalue": "name 'model' is not defined", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[1], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Merge LoRA weights into the base model\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m name, param \u001b[38;5;129;01min\u001b[39;00m \u001b[43mmodel\u001b[49m\u001b[38;5;241m.\u001b[39mnamed_parameters():\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mpeft_model\u001b[38;5;241m.\u001b[39mlora_weights:\n\u001b[1;32m 4\u001b[0m param\u001b[38;5;241m.\u001b[39mdata \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m trainer\u001b[38;5;241m.\u001b[39mpeft_model\u001b[38;5;241m.\u001b[39mlora_weights[name]\n", "\u001b[0;31mNameError\u001b[0m: name 'model' is not defined" ] } ], "source": [ "# Merge LoRA weights into the base model\n", "for name, param in model.named_parameters():\n", " if name in trainer.peft_model.lora_weights:\n", " param.data += trainer.peft_model.lora_weights[name]\n", "\n", "# Save the model with fused weights\n", "model.save_pretrained('/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full')\n", "tokenizer.save_pretrained('/root/notebooks/MT_TQ/TQ/TQTune/gemma-27b-tq_sft_finetuned-model-full')" ] }, { "cell_type": "code", "execution_count": 9, "id": "e5b4930d-92c5-46e8-9163-6e7f722e0c99", "metadata": { "execution": { "iopub.execute_input": "2025-05-08T19:13:24.762234Z", "iopub.status.busy": "2025-05-08T19:13:24.761972Z", "iopub.status.idle": "2025-05-08T19:13:50.993002Z", "shell.execute_reply": "2025-05-08T19:13:50.992329Z", "shell.execute_reply.started": "2025-05-08T19:13:24.762215Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading checkpoint shards: 100%|██████████| 12/12 [00:19<00:00, 1.60s/it]\n" ] } ], "source": [ "import torch\n", "from transformers import pipeline\n", "from random import randint\n", "import re\n", "\n", "model_id = \"google/gemma-3-27b-it\"\n", "model = model_class.from_pretrained(\n", " model_id,\n", " device_map=\"auto\",\n", " torch_dtype=torch_dtype,\n", " attn_implementation=\"eager\",\n", ")\n", "tokenizer = AutoTokenizer.from_pretrained(model_id)\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "5a428dea-261a-4c74-89a8-1b62d7ade5ab", "metadata": { "execution": { "iopub.execute_input": "2025-05-08T19:13:50.999539Z", "iopub.status.busy": "2025-05-08T19:13:50.999160Z", "iopub.status.idle": "2025-05-08T19:15:04.024652Z", "shell.execute_reply": "2025-05-08T19:15:04.022626Z", "shell.execute_reply.started": "2025-05-08T19:13:50.999517Z" } }, "outputs": [ { "ename": "NameError", "evalue": "name 'trainer' is not defined", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[10], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Merge LoRA weights into the base model\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m name, param \u001b[38;5;129;01min\u001b[39;00m model\u001b[38;5;241m.\u001b[39mnamed_parameters():\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m \u001b[43mtrainer\u001b[49m\u001b[38;5;241m.\u001b[39mpeft_model\u001b[38;5;241m.\u001b[39mlora_weights:\n\u001b[1;32m 4\u001b[0m param\u001b[38;5;241m.\u001b[39mdata \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m trainer\u001b[38;5;241m.\u001b[39mpeft_model\u001b[38;5;241m.\u001b[39mlora_weights[name]\n\u001b[1;32m 6\u001b[0m \u001b[38;5;66;03m# Save the model with fused weights\u001b[39;00m\n", "\u001b[0;31mNameError\u001b[0m: name 'trainer' is not defined" ] } ], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "c7e7172e-db49-40f3-a0d6-9a87a3b2cf80", "metadata": { "execution": { "iopub.status.busy": "2025-05-08T19:15:04.026597Z", "iopub.status.idle": "2025-05-08T19:15:04.026984Z", "shell.execute_reply": "2025-05-08T19:15:04.026875Z", "shell.execute_reply.started": "2025-05-08T19:15:04.026863Z" }, "scrolled": true }, "outputs": [], "source": [ "pipe = pipeline(\"text-generation\", model=model, tokenizer=tokenizer)\n", "rand_idx = randint(0, len(dataset[\"test\"]))\n", "test_sample = hf_dataset[\"test\"][rand_idx]\n", "stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"\")]\n", "prompt = pipe.tokenizer.apply_chat_template(test_sample[\"messages\"][:1], tokenize=False, add_generation_prompt=True)\n", "\n", "outputs = pipe(prompt, max_new_tokens=1024, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=stop_token_ids, disable_compile=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "de05b438-ca77-4b95-b2b1-32ea7ae033a5", "metadata": { "execution": { "iopub.status.busy": "2025-05-08T19:15:04.028819Z", "iopub.status.idle": "2025-05-08T19:15:04.029072Z", "shell.execute_reply": "2025-05-08T19:15:04.028971Z", "shell.execute_reply.started": "2025-05-08T19:15:04.028960Z" } }, "outputs": [], "source": [ "start = outputs[0]['generated_text'].split(r\"model\")[1].strip().find(\"{\")\n", "end = outputs[0]['generated_text'].split(r\"model\")[1].strip().rfind(\"}\")\n", "print(start, end)\n", "print(outputs[0]['generated_text'].split(r\"model\")[1].strip()[start:end + 1])\n", "json.loads(outputs[0]['generated_text'].split(r\"model\")[1].strip()[start:end + 1])\n", "rand_idx" ] }, { "cell_type": "code", "execution_count": null, "id": "60b3da99-0edc-4ef6-b0e0-be7d046eaa02", "metadata": { "execution": { "iopub.status.busy": "2025-05-08T19:15:04.030913Z", "iopub.status.idle": "2025-05-08T19:15:04.031227Z", "shell.execute_reply": "2025-05-08T19:15:04.031122Z", "shell.execute_reply.started": "2025-05-08T19:15:04.031111Z" } }, "outputs": [], "source": [ "json.loads(hf_dataset[\"test\"][81][\"messages\"][1]['content'])" ] }, { "cell_type": "code", "execution_count": null, "id": "cdc44250-e3b9-4870-bce5-23f475023962", "metadata": { "execution": { "iopub.status.busy": "2025-05-08T19:15:04.032999Z", "iopub.status.idle": "2025-05-08T19:15:04.033327Z", "shell.execute_reply": "2025-05-08T19:15:04.033207Z", "shell.execute_reply.started": "2025-05-08T19:15:04.033196Z" }, "scrolled": true }, "outputs": [], "source": [ "import torch\n", "from transformers import pipeline\n", "from random import randint\n", "import re\n", "from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, BitsAndBytesConfig\n", "from transformers import AutoProcessor, Gemma3ForConditionalGeneration\n", "device = torch.device(\"cuda:0\")\n", "\n", "model_class = Gemma3ForConditionalGeneration\n", "torch_dtype = torch.bfloat16\n", "\n", "model_id = \"gemma-27b-tq_sft_finetuned-model\"\n", "model = model_class.from_pretrained(\n", " model_id,\n", " device_map=\"auto\",\n", " torch_dtype=torch_dtype,\n", " attn_implementation=\"eager\",\n", ")\n", "tokenizer = AutoTokenizer.from_pretrained(model_id)\n", "pipe = pipeline(\"text-generation\", model=model, tokenizer=tokenizer)" ] }, { "cell_type": "code", "execution_count": null, "id": "ce4070a9-5291-477a-bb3f-867b7971e391", "metadata": { "execution": { "iopub.status.busy": "2025-05-08T19:15:04.035085Z", "iopub.status.idle": "2025-05-08T19:15:04.035416Z", "shell.execute_reply": "2025-05-08T19:15:04.035307Z", "shell.execute_reply.started": "2025-05-08T19:15:04.035295Z" } }, "outputs": [], "source": [ "def extract_json_data(json_string):\n", " key_pattern = r'\"(.*?)\"\\s*:\\s*'\n", " value_pattern = r'(?:\"(.*?)\"|(\\d+)|$$(.*?)$$|\\{(.*?)\\})'\n", " matches = re.finditer(key_pattern + value_pattern, json_string, re.DOTALL) \n", " data = {}\n", " for match in matches:\n", " key = match.group(1)\n", " value = match.group(2) or match.group(3) or match.group(4) or match.group(5) \n", " if value:\n", " try:\n", " value = json.loads(value)\n", " except (json.JSONDecodeError, TypeError):\n", " pass\n", " data[key] = value\n", " return data" ] }, { "cell_type": "code", "execution_count": null, "id": "4940ab0c-ff5a-4c1e-a543-b0e8be91a4cb", "metadata": { "execution": { "iopub.status.busy": "2025-05-08T19:15:04.037234Z", "iopub.status.idle": "2025-05-08T19:15:04.037745Z", "shell.execute_reply": "2025-05-08T19:15:04.037637Z", "shell.execute_reply.started": "2025-05-08T19:15:04.037626Z" } }, "outputs": [], "source": [ "rand_idx = randint(0, len(dataset[\"test\"]))\n", "test_predictions = []\n", "\n", "index = 9\n", "\n", "meta_data = test_meta[index]\n", "stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"\")]\n", "prompt = pipe.tokenizer.apply_chat_template(hf_dataset[\"test\"][index][\"messages\"][:1], tokenize=False, add_generation_prompt=True)\n", "outputs = pipe(prompt, max_new_tokens=2048, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=stop_token_ids, disable_compile=True)\n", "start = outputs[0]['generated_text'].split(r\"model\")[1].strip().find(\"{\")\n", "end = outputs[0]['generated_text'].split(r\"model\")[1].strip().rfind(\"}\")\n", "try:\n", " pred_dict = json.loads(outputs[0]['generated_text'].split(r\"model\")[1].strip()[start:end + 1])\n", "except:\n", " start = outputs[0]['generated_text'].split(r\"model\")[1].strip().find(\"{\")\n", " end = outputs[0]['generated_text'].split(r\"model\")[1].strip().rfind(\"}\")\n", " pred_dict = outputs[0]['generated_text'].split(r\"model\")[1].strip()[start:end + 1]" ] }, { "cell_type": "code", "execution_count": null, "id": "fdf03584-7cd0-40cc-af95-87279a2dc05e", "metadata": { "execution": { "iopub.status.busy": "2025-05-08T19:15:04.039492Z", "iopub.status.idle": "2025-05-08T19:15:04.039810Z", "shell.execute_reply": "2025-05-08T19:15:04.039704Z", "shell.execute_reply.started": "2025-05-08T19:15:04.039693Z" } }, "outputs": [], "source": [ "pred_dict" ] }, { "cell_type": "code", "execution_count": null, "id": "80603718-a168-4e4c-aa55-842dfb20f265", "metadata": { "execution": { "iopub.status.busy": "2025-05-08T19:15:04.041594Z", "iopub.status.idle": "2025-05-08T19:15:04.041970Z", "shell.execute_reply": "2025-05-08T19:15:04.041865Z", "shell.execute_reply.started": "2025-05-08T19:15:04.041854Z" } }, "outputs": [], "source": [ "hf_dataset[\"test\"][index][\"messages\"][1]" ] }, { "cell_type": "code", "execution_count": null, "id": "6d3731c9-4686-453f-8c91-e9477fe5541c", "metadata": { "execution": { "iopub.status.busy": "2025-05-08T19:15:04.043675Z", "iopub.status.idle": "2025-05-08T19:15:04.043977Z", "shell.execute_reply": "2025-05-08T19:15:04.043872Z", "shell.execute_reply.started": "2025-05-08T19:15:04.043861Z" } }, "outputs": [], "source": [ "batch_size = 8\n", "test_predictions = []\n", "\n", "for i in tqdm(range(0, len(hf_dataset[\"test\"]), batch_size)):\n", " batch_samples = hf_dataset[\"test\"][i:i + batch_size][\"messages\"]\n", " batch_meta = test_meta[i:i + batch_size]\n", " prompts = [\n", " pipe.tokenizer.apply_chat_template(sample[:1], tokenize=False, add_generation_prompt=True)\n", " for sample in batch_samples\n", " ]\n", " outputs = pipe(prompts, max_new_tokens=2048, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"\")], disable_compile=True)\n", "\n", " for index, output in tqdm(enumerate(tqdm(outputs))):\n", " output_dict = {}\n", " start = output[0]['generated_text'].split(r\"model\")[1].strip().find(\"{\")\n", " end = output[0]['generated_text'].split(r\"model\")[1].strip().rfind(\"}\")\n", " try:\n", " pred_dict = json.loads(output[0]['generated_text'].split(r\"model\")[1].strip()[start:end + 1])\n", " except:\n", " pred_dict = output[0]['generated_text'].split(r\"model\")[1].strip()[start:end + 1]\n", " \n", " output_dict.update(batch_meta[index])\n", " output_dict[\"predictions\"] = pred_dict\n", " output_dict[\"human-annotation\"] = batch_samples[index][1]['content']\n", " output_dict[\"prompt\"] = batch_samples[index][0]['content']\n", " test_predictions.append(output_dict)" ] }, { "cell_type": "code", "execution_count": null, "id": "616eb30a-eac2-4229-b86c-24eca7534cc6", "metadata": { "execution": { "iopub.status.busy": "2025-05-08T19:15:04.045755Z", "iopub.status.idle": "2025-05-08T19:15:04.046057Z", "shell.execute_reply": "2025-05-08T19:15:04.045954Z", "shell.execute_reply.started": "2025-05-08T19:15:04.045943Z" } }, "outputs": [], "source": [ "with open(\"/root/notebooks/trashspace/gemma_finetuned_expertdata/test_pred.json\", 'w') as json_file:\n", " json.dump(test_predictions, json_file)" ] }, { "cell_type": "code", "execution_count": null, "id": "71480057-d6b9-4499-a8d7-26bf3f3f9342", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "ce73604c-4bbb-46c7-8433-d957b0e10405", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "4741a722-a772-44bf-949e-e77671a4ef03", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "07c2adef-c2f6-4ba9-98f7-277cce2701d0", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "1adfa27b-9bfa-4479-be3f-5149a2237c1f", "metadata": { "execution": { "iopub.status.busy": "2025-05-08T19:15:04.047823Z", "iopub.status.idle": "2025-05-08T19:15:04.048130Z", "shell.execute_reply": "2025-05-08T19:15:04.048026Z", "shell.execute_reply.started": "2025-05-08T19:15:04.048015Z" } }, "outputs": [], "source": [ "data = json.loads(test_sample['messages'][1]['content'])\n", "data" ] }, { "cell_type": "code", "execution_count": null, "id": "750d3454-6300-469b-bdc3-77cce45a00ce", "metadata": { "execution": { "iopub.status.busy": "2025-05-08T19:15:04.049897Z", "iopub.status.idle": "2025-05-08T19:15:04.050203Z", "shell.execute_reply": "2025-05-08T19:15:04.050099Z", "shell.execute_reply.started": "2025-05-08T19:15:04.050088Z" } }, "outputs": [], "source": [ "print(len(hf_dataset[\"test\"]))" ] }, { "cell_type": "code", "execution_count": null, "id": "3be45a2d-336f-4899-a8e9-e000437fab8c", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "248182ff-bec8-46ff-bc34-14b523d877bf", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "timedlibs", "language": "python", "name": "timedlibs" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 5 }