{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "e6d20008-a91c-4618-baa0-5991e031f1bd", "metadata": { "execution": { "iopub.execute_input": "2025-05-13T21:48:57.985184Z", "iopub.status.busy": "2025-05-13T21:48:57.984795Z", "iopub.status.idle": "2025-05-13T21:51:48.369715Z", "shell.execute_reply": "2025-05-13T21:51:48.368907Z", "shell.execute_reply.started": "2025-05-13T21:48:57.985144Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/root/notebooks/MT_TQ/Libraries/timedlibs/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "from transformers import AutoProcessor, Gemma3ForConditionalGeneration, Trainer, TrainingArguments, DataCollatorForSeq2Seq\n", "import torch\n", "from peft import LoraConfig, get_peft_model\n", "\n", "import os\n", "from tqdm import tqdm\n", "import json\n", "\n", "import random\n", "from datasets import load_dataset\n", "from datasets import Dataset, DatasetDict" ] }, { "cell_type": "code", "execution_count": 3, "id": "67f95fc8-a9d8-48cf-a551-7d30781cdb55", "metadata": { "execution": { "iopub.execute_input": "2025-05-13T21:53:58.075473Z", "iopub.status.busy": "2025-05-13T21:53:58.074767Z", "iopub.status.idle": "2025-05-13T21:53:58.767860Z", "shell.execute_reply": "2025-05-13T21:53:58.767319Z", "shell.execute_reply.started": "2025-05-13T21:53:58.075446Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 8/8 [00:00<00:00, 22.76it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['messages'],\n", " num_rows: 309\n", " })\n", " test: Dataset({\n", " features: ['messages'],\n", " num_rows: 343\n", " })\n", "})\n" ] } ], "source": [ "data_path = (\n", " \"/root/notebooks/MT_TQ/Caches/May2025/tquality.annotated.data/parsed/pldl/\"\n", ")\n", "\n", "json_files = [\n", " os.path.join(root, file)\n", " for root, _, files in os.walk(data_path)\n", " for file in files\n", " if file.endswith(\".json\")\n", "]\n", "\n", "training_samples = []\n", "testing_samples = []\n", "\n", "for json_file in tqdm(json_files):\n", " with open(json_file, \"r\") as file:\n", " data = json.load(file)\n", " sampled_items = data[\"data\"]\n", " if \"test\" in json_file:\n", " testing_samples.extend(sampled_items)\n", " if \"train\" in json_file:\n", " training_samples.extend(sampled_items)\n", "\n", "training_datapoints = []\n", "testing_datapoints = []\n", "\n", "for idx, sample in enumerate(training_samples):\n", " datapoint = {\"input\": {}}\n", " datapoint[\"input\"][\"src_text\"] = sample[\"src_text\"]\n", " datapoint[\"input\"][\"tgt_text\"] = sample[\"main_tgt_text\"]\n", " datapoint[\"input\"][\"src_prev\"] = sample[\"tt_src_prev\"]\n", " datapoint[\"input\"][\"src_next\"] = sample[\"tt_src_next\"]\n", " datapoint[\"input\"][\"tgt_prev\"] = sample[\"tt_tgt_prev\"]\n", " datapoint[\"input\"][\"tgt_next\"] = sample[\"tt_tgt_next\"]\n", " datapoint[\"input\"][\"src_lang\"] = sample[\"src_lang\"]\n", " datapoint[\"input\"][\"tgt_lang\"] = sample[\"tgt_lang\"]\n", " datapoint[\"input\"][\"start_frame\"] = sample[\"start_frame\"]\n", " datapoint[\"input\"][\"end_frame\"] = sample[\"end_frame\"]\n", " datapoint[\"input\"][\"title_id\"] = sample[\"title_id\"]\n", " datapoint[\"input\"][\"alt_tgt_text\"]= sample[\"alt_tgt_text\"]\n", " datapoint[\"input\"][\"id\"] = idx\n", " datapoint[\"evaluation\"] = sample[\"labelers\"][0][\"annotation\"]\n", " training_datapoints.append(datapoint)\n", "\n", "for idx, sample in enumerate(testing_samples):\n", " datapoint = {\"input\": {}}\n", " datapoint[\"input\"][\"src_text\"] = sample[\"src_text\"]\n", " datapoint[\"input\"][\"tgt_text\"] = sample[\"main_tgt_text\"]\n", " datapoint[\"input\"][\"src_prev\"] = sample[\"tt_src_prev\"]\n", " datapoint[\"input\"][\"src_next\"] = sample[\"tt_src_next\"]\n", " datapoint[\"input\"][\"tgt_prev\"] = sample[\"tt_tgt_prev\"]\n", " datapoint[\"input\"][\"tgt_next\"] = sample[\"tt_tgt_next\"]\n", " datapoint[\"input\"][\"src_lang\"] = sample[\"src_lang\"]\n", " datapoint[\"input\"][\"tgt_lang\"] = sample[\"tgt_lang\"]\n", " datapoint[\"input\"][\"start_frame\"] = sample[\"start_frame\"]\n", " datapoint[\"input\"][\"end_frame\"] = sample[\"end_frame\"]\n", " datapoint[\"input\"][\"title_id\"] = sample[\"title_id\"]\n", " datapoint[\"input\"][\"alt_tgt_text\"]= sample[\"alt_tgt_text\"]\n", " datapoint[\"input\"][\"id\"] = idx\n", " datapoint[\"evaluation\"] = sample[\"labelers\"][0][\"annotation\"]\n", " testing_datapoints.append(datapoint)\n", "\n", "system_message = \"You are a helpful assistant who is an expert in estimating quality of translations.\"\n", "\n", "output_template = '''\n", "{\n", " \"Accuracy Issues\": [\n", " {\n", " \"Error Span\": \"\",\n", " \"Error Explanation\": \"\",\n", " \"Error Quality Category\": \"\",\n", " \"Error Quality Tags\": [],\n", " \"Error Severity\": \"\"\n", " }\n", " ],\n", " \"Accuracy Score\": \"\",\n", " \"Readability Issues\": [\n", " {\n", " \"Error Span\": \"\",\n", " \"Error Explanation\": \"\",\n", " \"Error Quality Category\": \"\",\n", " \"Error Quality Tags\": [],\n", " \"Error Severity\": \"\"\n", " }\n", " ],\n", " \"Readability Score\": \"\"\n", "}'''\n", "\n", "def create_conversation(input_sample, output_sample):\n", " return {\n", " \"messages\": [\n", " # {\"role\": \"system\", \"content\": system_message},\n", " {\"role\": \"user\", \"content\": input_sample},\n", " {\"role\": \"assistant\", \"content\": output_sample}\n", " ]\n", " }\n", "\n", "def create_dataset(datapoints, template_string):\n", " dataset = []\n", " meta = []\n", " for datapoint in datapoints:\n", " src_text = datapoint['input']['src_text']\n", " tgt_text = datapoint['input']['tgt_text']\n", " src_prev = datapoint['input']['src_prev']\n", " src_next = datapoint['input']['src_next'] \n", " tgt_prev = datapoint['input']['tgt_prev']\n", " tgt_next = datapoint['input']['tgt_next']\n", " src_lang = datapoint['input']['src_lang']\n", " tgt_lang = datapoint['input']['tgt_lang']\n", " \n", " start_frame = datapoint['input']['start_frame']\n", " end_frame = datapoint['input']['end_frame']\n", " title_id = datapoint['input']['title_id']\n", " output = datapoint['evaluation']\n", " idx = datapoint['input']['id']\n", " if len(output['Accuracy Issues']) != 0 or len(output['Readability Issues']) != 0:\n", " item = template_string.format(src_text=src_text, tgt_text=tgt_text, \n", " src_prev=src_prev, src_next=src_next, \n", " tgt_prev=tgt_prev, tgt_next=tgt_next, \n", " src_lang=src_lang, tgt_lang=tgt_lang,\n", " template=output_template)\n", " \n", " dataset.append(create_conversation(item, json.dumps(output)))\n", " meta.append({\"id\": idx, \"start_frame\": start_frame, \"end_frame\": end_frame, \"title_id\": title_id})\n", " \n", " return dataset, meta\n", " \n", "def dataset_prep(datapoints):\n", " with open(\"prompts.txt\") as file:\n", " template_string = file.read()\n", " dataset, meta = create_dataset(datapoints, template_string)\n", " return dataset, meta\n", "\n", "train_dataset, train_meta = dataset_prep(training_datapoints)\n", "test_dataset, test_meta = dataset_prep(testing_datapoints)\n", "\n", "dataset = {\"train\": train_dataset, \"test\": test_dataset}\n", "\n", "def convert_to_hf_dataset(dataset):\n", " train_dataset = Dataset.from_list(dataset['train'])\n", " test_dataset = Dataset.from_list(dataset['test'])\n", " \n", " hf_dataset = DatasetDict({\n", " 'train': train_dataset,\n", " 'test': test_dataset\n", " })\n", " \n", " return hf_dataset\n", "\n", "hf_dataset = convert_to_hf_dataset(dataset)\n", "print(hf_dataset)" ] }, { "cell_type": "code", "execution_count": 4, "id": "8b52f143-1077-4da6-ac92-b1dce5cdc17c", "metadata": { "execution": { "iopub.execute_input": "2025-05-13T21:54:12.568533Z", "iopub.status.busy": "2025-05-13T21:54:12.568078Z", "iopub.status.idle": "2025-05-13T21:54:49.724121Z", "shell.execute_reply": "2025-05-13T21:54:49.723481Z", "shell.execute_reply.started": "2025-05-13T21:54:12.568507Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading checkpoint shards: 100%|██████████| 12/12 [00:18<00:00, 1.58s/it]\n" ] } ], "source": [ "import torch\n", "from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, BitsAndBytesConfig\n", "from transformers import AutoProcessor, Gemma3ForConditionalGeneration\n", "device = torch.device(\"cuda:0\")\n", "\n", "# Hugging Face model id\n", "model_id = \"google/gemma-3-27b-it\" # or `google/gemma-3-4b-pt`, `google/gemma-3-12b-pt`, `google/gemma-3-27b-pt`\n", "\n", "# Select model class based on id\n", "if model_id == \"google/gemma-3-27b-it\":\n", " model_class = Gemma3ForConditionalGeneration\n", "else:\n", " model_class = AutoModelForImageTextToText\n", "\n", "torch_dtype = torch.bfloat16\n", "\n", "model_kwargs = dict(\n", " attn_implementation=\"eager\",\n", " torch_dtype=torch_dtype,\n", " device_map=\"auto\",\n", ")\n", "\n", "model = model_class.from_pretrained(model_id, **model_kwargs)\n", "tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-3-27b-it\") # Load the Instruction Tokenizer to use the official Gemma template" ] }, { "cell_type": "code", "execution_count": 5, "id": "8443dfd8-6193-480c-9937-f6e0c43a9f56", "metadata": { "execution": { "iopub.execute_input": "2025-05-13T21:55:12.713958Z", "iopub.status.busy": "2025-05-13T21:55:12.713495Z", "iopub.status.idle": "2025-05-13T21:55:12.717707Z", "shell.execute_reply": "2025-05-13T21:55:12.717199Z", "shell.execute_reply.started": "2025-05-13T21:55:12.713930Z" } }, "outputs": [], "source": [ "from peft import LoraConfig\n", "\n", "peft_config = LoraConfig(\n", " lora_alpha=128,\n", " lora_dropout=0.05,\n", " r=16,\n", " bias=\"none\",\n", " target_modules=\"all-linear\",\n", " task_type=\"CAUSAL_LM\",\n", " modules_to_save=[\"lm_head\", \"embed_tokens\"] # make sure to save the lm_head and embed_tokens as you train the special tokens\n", ")" ] }, { "cell_type": "code", "execution_count": 6, "id": "8f2b8371-ba1b-44ff-9462-d0c90335f82a", "metadata": { "execution": { "iopub.execute_input": "2025-05-13T21:55:22.076515Z", "iopub.status.busy": "2025-05-13T21:55:22.076029Z", "iopub.status.idle": "2025-05-13T21:55:22.783524Z", "shell.execute_reply": "2025-05-13T21:55:22.782937Z", "shell.execute_reply.started": "2025-05-13T21:55:22.076489Z" } }, "outputs": [], "source": [ "from trl import SFTConfig\n", "\n", "args = SFTConfig(\n", " output_dir=\"may13-gemma-27b-tq_sft_finetuned-model\",\n", " max_seq_length=2048,\n", " packing=True,\n", " num_train_epochs=1,\n", " per_device_train_batch_size=1,\n", " gradient_accumulation_steps=4,\n", " gradient_checkpointing=True,\n", " optim=\"adamw_torch_fused\",\n", " logging_steps=1,\n", " save_strategy=\"epoch\",\n", " learning_rate=1e-4,\n", " fp16=True if torch_dtype == torch.float16 else False,\n", " bf16=True if torch_dtype == torch.bfloat16 else False,\n", " max_grad_norm=0.3,\n", " warmup_ratio=0.03,\n", " lr_scheduler_type=\"constant\",\n", " push_to_hub=True,\n", " report_to=\"tensorboard\",\n", " dataset_kwargs={\n", " \"add_special_tokens\": False,\n", " \"append_concat_token\": True,\n", " },\n", " no_cuda=False,\n", ")" ] }, { "cell_type": "code", "execution_count": 7, "id": "2be55b87-70c9-4973-b0db-33154c272e47", "metadata": { "execution": { "iopub.execute_input": "2025-05-13T21:55:25.765385Z", "iopub.status.busy": "2025-05-13T21:55:25.764949Z", "iopub.status.idle": "2025-05-13T21:55:36.592163Z", "shell.execute_reply": "2025-05-13T21:55:36.591614Z", "shell.execute_reply.started": "2025-05-13T21:55:25.765360Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Converting train dataset to ChatML: 100%|██████████| 309/309 [00:00<00:00, 9533.70 examples/s]\n", "Applying chat template to train dataset: 100%|██████████| 309/309 [00:00<00:00, 4443.06 examples/s]\n", "Tokenizing train dataset: 100%|██████████| 309/309 [00:01<00:00, 226.22 examples/s]\n", "Packing train dataset: 100%|██████████| 309/309 [00:00<00:00, 102364.74 examples/s]\n", "No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.\n" ] } ], "source": [ "from trl import SFTTrainer\n", "\n", "# Create Trainer object\n", "trainer = SFTTrainer(\n", " model=model,\n", " args=args,\n", " train_dataset=hf_dataset[\"train\"],\n", " peft_config=peft_config,\n", " processing_class=tokenizer\n", ")" ] }, { "cell_type": "code", "execution_count": 8, "id": "d8d82767-27ed-48ed-ad22-3f3cf2dff15e", "metadata": { "execution": { "iopub.execute_input": "2025-05-13T22:00:25.107226Z", "iopub.status.busy": "2025-05-13T22:00:25.106569Z", "iopub.status.idle": "2025-05-13T22:27:35.945604Z", "shell.execute_reply": "2025-05-13T22:27:35.944775Z", "shell.execute_reply.started": "2025-05-13T22:00:25.107196Z" }, "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.\n" ] }, { "data": { "text/html": [ "\n", "
| Step | \n", "Training Loss | \n", "
|---|---|
| 1 | \n", "10.801900 | \n", "
| 2 | \n", "8.381400 | \n", "
| 3 | \n", "6.970200 | \n", "
| 4 | \n", "5.784300 | \n", "
| 5 | \n", "4.970800 | \n", "
| 6 | \n", "4.389700 | \n", "
| 7 | \n", "4.325000 | \n", "
| 8 | \n", "3.557000 | \n", "
| 9 | \n", "3.357700 | \n", "
| 10 | \n", "3.092500 | \n", "
| 11 | \n", "3.170300 | \n", "
| 12 | \n", "2.648500 | \n", "
| 13 | \n", "3.067800 | \n", "
| 14 | \n", "2.377100 | \n", "
| 15 | \n", "2.847700 | \n", "
| 16 | \n", "2.628800 | \n", "
| 17 | \n", "2.630800 | \n", "
| 18 | \n", "2.820900 | \n", "
| 19 | \n", "2.596700 | \n", "
| 20 | \n", "2.675300 | \n", "
| 21 | \n", "2.846300 | \n", "
| 22 | \n", "2.706700 | \n", "
| 23 | \n", "2.645100 | \n", "
| 24 | \n", "2.214600 | \n", "
| 25 | \n", "2.790700 | \n", "
| 26 | \n", "2.640700 | \n", "
| 27 | \n", "2.908900 | \n", "
| 28 | \n", "2.690400 | \n", "
| 29 | \n", "2.807200 | \n", "
| 30 | \n", "2.713600 | \n", "
| 31 | \n", "2.563200 | \n", "
| 32 | \n", "2.412700 | \n", "
| 33 | \n", "2.627700 | \n", "
| 34 | \n", "2.431800 | \n", "
| 35 | \n", "2.240600 | \n", "
| 36 | \n", "2.650300 | \n", "
| 37 | \n", "2.014900 | \n", "
| 38 | \n", "2.463100 | \n", "
| 39 | \n", "2.283300 | \n", "
| 40 | \n", "2.450500 | \n", "
| 41 | \n", "2.570400 | \n", "
| 42 | \n", "2.550500 | \n", "
| 43 | \n", "2.530600 | \n", "
| 44 | \n", "2.551400 | \n", "
| 45 | \n", "2.383000 | \n", "
| 46 | \n", "2.550500 | \n", "
| 47 | \n", "2.575900 | \n", "
| 48 | \n", "2.494300 | \n", "
| 49 | \n", "2.387200 | \n", "
| 50 | \n", "2.318800 | \n", "
| 51 | \n", "2.365200 | \n", "
| 52 | \n", "2.190100 | \n", "
| 53 | \n", "2.419100 | \n", "
| 54 | \n", "2.290900 | \n", "
| 55 | \n", "2.152500 | \n", "
| 56 | \n", "2.398700 | \n", "
| 57 | \n", "2.982500 | \n", "
| 58 | \n", "2.380200 | \n", "
| 59 | \n", "2.357500 | \n", "
| 60 | \n", "2.386300 | \n", "
| 61 | \n", "2.741300 | \n", "
| 62 | \n", "2.850300 | \n", "
| 63 | \n", "2.682100 | \n", "
| 64 | \n", "2.972100 | \n", "
| 65 | \n", "2.237800 | \n", "
| 66 | \n", "2.518300 | \n", "
| 67 | \n", "2.520700 | \n", "
| 68 | \n", "2.122700 | \n", "
| 69 | \n", "2.210200 | \n", "
| 70 | \n", "2.414000 | \n", "
| 71 | \n", "2.348200 | \n", "
| 72 | \n", "2.470800 | \n", "
| 73 | \n", "2.417400 | \n", "
| 74 | \n", "2.562900 | \n", "
| 75 | \n", "2.286800 | \n", "
| 76 | \n", "2.671400 | \n", "
| 77 | \n", "2.176200 | \n", "
| 78 | \n", "2.284200 | \n", "
| 79 | \n", "2.354700 | \n", "
| 80 | \n", "2.363400 | \n", "
"
],
"text/plain": [
"