{ "cells": [ { "cell_type": "code", "execution_count": 10, "id": "8b3ee6e2-ca9c-40fa-b4c6-a9596f075f79", "metadata": { "execution": { "iopub.execute_input": "2025-05-09T17:36:47.763713Z", "iopub.status.busy": "2025-05-09T17:36:47.763339Z", "iopub.status.idle": "2025-05-09T17:36:47.768648Z", "shell.execute_reply": "2025-05-09T17:36:47.768166Z", "shell.execute_reply.started": "2025-05-09T17:36:47.763676Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "env: OPENAI_API_KEY=\"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\"\n" ] } ], "source": [ "import argparse\n", "import concurrent\n", "from dotenv import load_dotenv\n", "from tqdm import tqdm\n", "import textgrad as tg\n", "from textgrad.tasks import load_task\n", "import numpy as np\n", "import random\n", "load_dotenv(override=True)\n", "import os\n", "import json\n", "\n", "%env OPENAI_API_KEY=\"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\"" ] }, { "cell_type": "code", "execution_count": 4, "id": "4ec9a29b-9162-4fe3-b32d-4de4397c6483", "metadata": { "execution": { "iopub.execute_input": "2025-05-09T17:33:04.417822Z", "iopub.status.busy": "2025-05-09T17:33:04.417437Z", "iopub.status.idle": "2025-05-09T17:33:04.429505Z", "shell.execute_reply": "2025-05-09T17:33:04.429029Z", "shell.execute_reply.started": "2025-05-09T17:33:04.417795Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "0it [00:00, ?it/s]\n" ] } ], "source": [ "data_path = \"/root/notebooks/MT_TQ/TQ/DataPrep_Prompting_Experiments/labeled_data/parsed/\"\n", "json_files = [os.path.join(root, file) for root, _, files in os.walk(data_path) for file in files if file.endswith('.json') and 'PLDL' in file]\n", "\n", "training_samples = []\n", "for json_file in tqdm(json_files):\n", " with open(json_file, 'r') as file:\n", " data = json.load(file)\n", " sampled_items = random.sample(data[\"data\"], 20)\n", " training_samples.extend(sampled_items)\n", "\n", "datapoints = []\n", "\n", "for sample in training_samples:\n", " datapoint = {\"input\":{}}\n", " datapoint[\"input\"][\"src_text\"] = sample[\"main_src_text\"]\n", " datapoint[\"input\"][\"tgt_text\"] = sample[\"tgt_text\"]\n", " datapoint[\"input\"][\"src_prev\"] = sample[\"tt_src_prev\"]\n", " datapoint[\"input\"][\"src_next\"] = sample[\"tt_src_next\"]\n", " datapoint[\"input\"][\"tgt_prev\"] = sample[\"tt_tgt_prev\"]\n", " datapoint[\"input\"][\"tgt_next\"] = sample[\"tt_tgt_next\"]\n", " datapoint[\"input\"][\"src_lang\"] = sample[\"src_lang\"]\n", " datapoint[\"input\"][\"tgt_lang\"] = sample[\"tgt_lang\"]\n", " datapoint[\"evaluation\"] = sample[\"labelers\"][0][\"annotation\"]\n", " datapoints.append(datapoint)" ] }, { "cell_type": "code", "execution_count": 5, "id": "a894ce72-d451-44fa-aaa5-85bf8e6dc9da", "metadata": { "execution": { "iopub.execute_input": "2025-05-09T17:33:40.240759Z", "iopub.status.busy": "2025-05-09T17:33:40.240243Z", "iopub.status.idle": "2025-05-09T17:33:40.244435Z", "shell.execute_reply": "2025-05-09T17:33:40.243818Z", "shell.execute_reply.started": "2025-05-09T17:33:40.240720Z" } }, "outputs": [], "source": [ "def set_seed(seed):\n", " np.random.seed(seed)\n", " random.seed(seed)" ] }, { "cell_type": "code", "execution_count": 6, "id": "4eeaa266-3ca2-4360-b80b-b38aa3bbdb70", "metadata": { "execution": { 
"iopub.execute_input": "2025-05-09T17:33:55.982807Z", "iopub.status.busy": "2025-05-09T17:33:55.982080Z", "iopub.status.idle": "2025-05-09T17:33:55.988522Z", "shell.execute_reply": "2025-05-09T17:33:55.987924Z", "shell.execute_reply.started": "2025-05-09T17:33:55.982770Z" } }, "outputs": [], "source": [ "def eval_sample(item, eval_fn, model):\n", " \"\"\"\n", " This function allows us to evaluate if an answer to a question in the prompt is a good answer.\n", "\n", " \"\"\"\n", " x, y = item\n", " x = tg.Variable(x, requires_grad=False, role_description=\"query to the language model\")\n", " y = tg.Variable(y, requires_grad=False, role_description=\"correct answer for the query\")\n", " response = model(x)\n", " try:\n", " eval_output_variable = eval_fn(inputs=dict(prediction=response, ground_truth_answer=y))\n", " return int(eval_output_variable.value)\n", " except:\n", " eval_output_variable = eval_fn([x, y, response])\n", " eval_output_parsed = eval_fn.parse_output(eval_output_variable)\n", " return int(eval_output_parsed)" ] }, { "cell_type": "code", "execution_count": 7, "id": "c7e57f9d-c0ff-4139-9e61-b93510599353", "metadata": { "execution": { "iopub.execute_input": "2025-05-09T17:34:08.606301Z", "iopub.status.busy": "2025-05-09T17:34:08.605538Z", "iopub.status.idle": "2025-05-09T17:34:08.612515Z", "shell.execute_reply": "2025-05-09T17:34:08.611911Z", "shell.execute_reply.started": "2025-05-09T17:34:08.606262Z" } }, "outputs": [], "source": [ "def eval_dataset(test_set, eval_fn, model, max_samples: int=None):\n", " if max_samples is None:\n", " max_samples = len(test_set)\n", " accuracy_list = []\n", " with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:\n", " futures = []\n", " for _, sample in enumerate(test_set):\n", " \n", " future = executor.submit(eval_sample, sample, eval_fn, model)\n", " futures.append(future)\n", " if len(futures) >= max_samples:\n", " break\n", " tqdm_loader = tqdm(concurrent.futures.as_completed(futures), total=len(futures), position=0)\n", " for future in tqdm_loader:\n", " acc_item = future.result()\n", " accuracy_list.append(acc_item)\n", " tqdm_loader.set_description(f\"Accuracy: {np.mean(accuracy_list)}\")\n", " return accuracy_list " ] }, { "cell_type": "code", "execution_count": 8, "id": "039af9f3-a124-4a50-98a7-e728a913c069", "metadata": { "execution": { "iopub.execute_input": "2025-05-09T17:34:22.703336Z", "iopub.status.busy": "2025-05-09T17:34:22.702980Z", "iopub.status.idle": "2025-05-09T17:34:22.707253Z", "shell.execute_reply": "2025-05-09T17:34:22.706781Z", "shell.execute_reply.started": "2025-05-09T17:34:22.703313Z" } }, "outputs": [], "source": [ "def run_validation_revert(system_prompt: tg.Variable, results, model, eval_fn, val_set):\n", " val_performance = np.mean(eval_dataset(val_set, eval_fn, model))\n", " previous_performance = np.mean(results[\"validation_acc\"][-1])\n", " print(\"val_performance: \", val_performance)\n", " print(\"previous_performance: \", previous_performance)\n", " previous_prompt = results[\"prompt\"][-1]\n", " \n", " if val_performance < previous_performance:\n", " print(f\"rejected prompt: {system_prompt.value}\")\n", " system_prompt.set_value(previous_prompt)\n", " val_performance = previous_performance\n", "\n", " results[\"validation_acc\"].append(val_performance)" ] }, { "cell_type": "code", "execution_count": 14, "id": "031ebb6e-f5ff-45b0-a810-d1bd81ef6d2a", "metadata": { "execution": { "iopub.execute_input": "2025-05-09T17:40:38.476352Z", "iopub.status.busy": "2025-05-09T17:40:38.475979Z", 
"iopub.status.idle": "2025-05-09T17:40:38.701947Z", "shell.execute_reply": "2025-05-09T17:40:38.701394Z", "shell.execute_reply.started": "2025-05-09T17:40:38.476327Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train/Val/Test Set Lengths: 50 100 100\n" ] } ], "source": [ "set_seed(12)\n", "llm_api_eval = tg.get_engine(engine_name=\"gpt-4o\")\n", "llm_api_test = tg.get_engine(engine_name=\"gpt-3.5-turbo-0125\")\n", "tg.set_backward_engine(llm_api_eval, override=True)\n", "\n", "# Load the data and the evaluation function\n", "train_set, val_set, test_set, eval_fn = load_task(\"BBH_object_counting\", evaluation_api=llm_api_eval)\n", "print(\"Train/Val/Test Set Lengths: \", len(train_set), len(val_set), len(test_set))\n", "STARTING_SYSTEM_PROMPT = train_set.get_task_description()" ] }, { "cell_type": "code", "execution_count": 15, "id": "bde34303-2f52-415f-b117-264e266b84f0", "metadata": { "execution": { "iopub.execute_input": "2025-05-09T17:40:39.330651Z", "iopub.status.busy": "2025-05-09T17:40:39.330285Z", "iopub.status.idle": "2025-05-09T17:40:39.398820Z", "shell.execute_reply": "2025-05-09T17:40:39.398116Z", "shell.execute_reply.started": "2025-05-09T17:40:39.330626Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/100 [00:00", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[15], line 18\u001b[0m\n\u001b[1;32m 15\u001b[0m optimizer \u001b[38;5;241m=\u001b[39m tg\u001b[38;5;241m.\u001b[39mTextualGradientDescent(engine\u001b[38;5;241m=\u001b[39mllm_api_eval, parameters\u001b[38;5;241m=\u001b[39m[system_prompt])\n\u001b[1;32m 17\u001b[0m results \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_acc\u001b[39m\u001b[38;5;124m\"\u001b[39m: [], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mprompt\u001b[39m\u001b[38;5;124m\"\u001b[39m: [], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalidation_acc\u001b[39m\u001b[38;5;124m\"\u001b[39m: []}\n\u001b[0;32m---> 18\u001b[0m results[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_acc\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mappend(\u001b[43meval_dataset\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtest_set\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43meval_fn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 19\u001b[0m results[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalidation_acc\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mappend(eval_dataset(val_set, eval_fn, model))\n\u001b[1;32m 20\u001b[0m results[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mprompt\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mappend(system_prompt\u001b[38;5;241m.\u001b[39mget_value())\n", "Cell \u001b[0;32mIn[7], line 15\u001b[0m, in \u001b[0;36meval_dataset\u001b[0;34m(test_set, eval_fn, model, max_samples)\u001b[0m\n\u001b[1;32m 13\u001b[0m tqdm_loader \u001b[38;5;241m=\u001b[39m tqdm(concurrent\u001b[38;5;241m.\u001b[39mfutures\u001b[38;5;241m.\u001b[39mas_completed(futures), total\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mlen\u001b[39m(futures), position\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m future \u001b[38;5;129;01min\u001b[39;00m tqdm_loader:\n\u001b[0;32m---> 15\u001b[0m acc_item 
\u001b[38;5;241m=\u001b[39m \u001b[43mfuture\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mresult\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 16\u001b[0m accuracy_list\u001b[38;5;241m.\u001b[39mappend(acc_item)\n\u001b[1;32m 17\u001b[0m tqdm_loader\u001b[38;5;241m.\u001b[39mset_description(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAccuracy: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnp\u001b[38;5;241m.\u001b[39mmean(accuracy_list)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", "File \u001b[0;32m/apps/python3.10/lib/python3.10/concurrent/futures/_base.py:451\u001b[0m, in \u001b[0;36mFuture.result\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 449\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m CancelledError()\n\u001b[1;32m 450\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_state \u001b[38;5;241m==\u001b[39m FINISHED:\n\u001b[0;32m--> 451\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__get_result\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 453\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_condition\u001b[38;5;241m.\u001b[39mwait(timeout)\n\u001b[1;32m 455\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_state \u001b[38;5;129;01min\u001b[39;00m [CANCELLED, CANCELLED_AND_NOTIFIED]:\n", "File \u001b[0;32m/apps/python3.10/lib/python3.10/concurrent/futures/_base.py:403\u001b[0m, in \u001b[0;36mFuture.__get_result\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 401\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_exception:\n\u001b[1;32m 402\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 403\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_exception\n\u001b[1;32m 404\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 405\u001b[0m \u001b[38;5;66;03m# Break a reference cycle with the exception in self._exception\u001b[39;00m\n\u001b[1;32m 406\u001b[0m \u001b[38;5;28mself\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", "File \u001b[0;32m/apps/python3.10/lib/python3.10/concurrent/futures/thread.py:58\u001b[0m, in \u001b[0;36m_WorkItem.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 57\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m---> 58\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 59\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mBaseException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m exc:\n\u001b[1;32m 60\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfuture\u001b[38;5;241m.\u001b[39mset_exception(exc)\n", "Cell \u001b[0;32mIn[6], line 8\u001b[0m, in \u001b[0;36meval_sample\u001b[0;34m(item, eval_fn, model)\u001b[0m\n\u001b[1;32m 6\u001b[0m x, y 
\u001b[38;5;241m=\u001b[39m item\n\u001b[1;32m 7\u001b[0m x \u001b[38;5;241m=\u001b[39m tg\u001b[38;5;241m.\u001b[39mVariable(x, requires_grad\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, role_description\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mquery to the language model\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 8\u001b[0m y \u001b[38;5;241m=\u001b[39m \u001b[43mtg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mVariable\u001b[49m\u001b[43m(\u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrequires_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrole_description\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcorrect answer for the query\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 9\u001b[0m response \u001b[38;5;241m=\u001b[39m model(x)\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n", "File \u001b[0;32m~/notebooks/MT_TQ/Libraries/timedlibs/lib/python3.10/site-packages/textgrad/variable.py:43\u001b[0m, in \u001b[0;36mVariable.__init__\u001b[0;34m(self, value, image_path, predecessors, requires_grad, role_description)\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\u001b[38;5;129;01mnot\u001b[39;00m requires_grad) \u001b[38;5;129;01mand\u001b[39;00m (\u001b[38;5;28mlen\u001b[39m(_predecessor_requires_grad) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m):\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIf the variable does not require grad, none of its predecessors should require grad.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 41\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIn this case, following predecessors require grad: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m_predecessor_requires_grad\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 43\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(value) \u001b[38;5;129;01min\u001b[39;00m [\u001b[38;5;28mstr\u001b[39m, \u001b[38;5;28mbytes\u001b[39m, \u001b[38;5;28mint\u001b[39m], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mValue must be a string, int, or image (bytes). Got: \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\u001b[38;5;28mtype\u001b[39m(value))\n\u001b[1;32m 44\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(value, \u001b[38;5;28mint\u001b[39m):\n\u001b[1;32m 45\u001b[0m value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mstr\u001b[39m(value)\n", "\u001b[0;31mAssertionError\u001b[0m: Value must be a string, int, or image (bytes). 
Got: " ] } ], "source": [ "train_loader = tg.tasks.DataLoader(train_set, batch_size=3, shuffle=True)\n", "\n", "\n", "# Testing the 0-shot performance of the evaluation engine\n", "system_prompt = tg.Variable(STARTING_SYSTEM_PROMPT, \n", " requires_grad=True, \n", " role_description=\"system prompt to the language model\")\n", "model_evaluation = tg.BlackboxLLM(llm_api_eval, system_prompt)\n", "\n", "system_prompt = tg.Variable(STARTING_SYSTEM_PROMPT, \n", " requires_grad=True,\n", " role_description=\"structured system prompt to a somewhat capable language model that specifies the behavior and strategies for the QA task\")\n", "model = tg.BlackboxLLM(llm_api_test, system_prompt)\n", "\n", "optimizer = tg.TextualGradientDescent(engine=llm_api_eval, parameters=[system_prompt])\n", "\n", "results = {\"test_acc\": [], \"prompt\": [], \"validation_acc\": []}\n", "results[\"test_acc\"].append(eval_dataset(test_set, eval_fn, model))\n", "results[\"validation_acc\"].append(eval_dataset(val_set, eval_fn, model))\n", "results[\"prompt\"].append(system_prompt.get_value())" ] }, { "cell_type": "code", "execution_count": null, "id": "47c15231-22ff-459b-b5cc-ca32aaa62332", "metadata": {}, "outputs": [], "source": [ "for epoch in range(3):\n", " for steps, (batch_x, batch_y) in enumerate((pbar := tqdm(train_loader, position=0))):\n", " pbar.set_description(f\"Training step {steps}. Epoch {epoch}\")\n", " optimizer.zero_grad()\n", " losses = []\n", " for (x, y) in zip(batch_x, batch_y):\n", " x = tg.Variable(x, requires_grad=False, role_description=\"query to the language model\")\n", " y = tg.Variable(y, requires_grad=False, role_description=\"correct answer for the query\")\n", " response = model(x)\n", " try:\n", " eval_output_variable = eval_fn(inputs=dict(prediction=response, ground_truth_answer=y))\n", " except:\n", " eval_output_variable = eval_fn([x, y, response])\n", " losses.append(eval_output_variable)\n", " total_loss = tg.sum(losses)\n", " total_loss.backward()\n", " optimizer.step()\n", " \n", " run_validation_revert(system_prompt, results, model, eval_fn, val_set)\n", " \n", " print(\"sys prompt: \", system_prompt)\n", " test_acc = eval_dataset(test_set, eval_fn, model)\n", " results[\"test_acc\"].append(test_acc)\n", " results[\"prompt\"].append(system_prompt.get_value())\n", " if steps == 3:\n", " break" ] }, { "cell_type": "code", "execution_count": null, "id": "3c5e93f5-8d1c-4b87-a6d1-811714982d47", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "67a4583f-162c-4e2d-b061-798f6c676a28", "metadata": {}, "outputs": [], "source": [ "class TranslationQualityAssessor(dspy.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.assess = dspy.ChainOfThought(TranslationQualitySignature)\n", "\n", " def forward(self, src_lang, tgt_lang, src_text, translation, src_prev=\"\", tgt_prev=\"\", src_next=\"\", tgt_next=\"\"):\n", " context = f\"\"\"Previous Context:\n", " Source: {src_prev}\n", " Translation: {tgt_prev}\n", " \n", " Next Context:\n", " Source: {src_next}\n", " Translation: {tgt_next}\"\"\"\n", "\n", " result = self.assess(\n", " context=context,\n", " source=f\"Source ({src_lang}): {src_text}\",\n", " translation=f\"Translation ({tgt_lang}): {translation}\"\n", " )\n", " \n", " return result.evaluation\n", "\n", "class TranslationMetrics:\n", " @staticmethod\n", " def exact_match_score(pred, gold):\n", " try:\n", " pred_json = json.loads(pred)\n", " gold_json = gold\n", " \n", " accuracy_match = 
 { "cell_type": "code", "execution_count": null, "id": "67a4583f-162c-4e2d-b061-798f6c676a28", "metadata": {}, "outputs": [], "source": [
 "class TranslationQualityAssessor(dspy.Module):\n",
 "    def __init__(self):\n",
 "        super().__init__()\n",
 "        self.assess = dspy.ChainOfThought(TranslationQualitySignature)\n",
 "\n",
 "    def forward(self, context, source, translation):\n",
 "        # Inputs arrive pre-formatted (see prepare_dataset below), matching both the\n",
 "        # signature's fields and the examples' .with_inputs(...) declaration\n",
 "        result = self.assess(context=context, source=source, translation=translation)\n",
 "        return result.evaluation\n",
 "\n",
 "class TranslationMetrics:\n",
 "    @staticmethod\n",
 "    def exact_match_score(pred, gold):\n",
 "        try:\n",
 "            pred_json = json.loads(pred)\n",
 "            gold_json = gold\n",
 "\n",
 "            accuracy_match = (str(pred_json.get('Accuracy Score')) == str(gold_json.get('Accuracy Score')))\n",
 "            readability_match = (str(pred_json.get('Readability Score')) == str(gold_json.get('Readability Score')))\n",
 "\n",
 "            return (accuracy_match and readability_match)\n",
 "        except Exception:\n",
 "            return False\n",
 "\n",
 "    @staticmethod\n",
 "    def partial_match_score(pred, gold):\n",
 "        try:\n",
 "            pred_json = json.loads(pred)\n",
 "            gold_json = gold\n",
 "\n",
 "            # Score comparison\n",
 "            accuracy_diff = abs(float(pred_json.get('Accuracy Score', 0)) - float(gold_json.get('Accuracy Score', 0)))\n",
 "            readability_diff = abs(float(pred_json.get('Readability Score', 0)) - float(gold_json.get('Readability Score', 0)))\n",
 "\n",
 "            # Issues comparison\n",
 "            pred_accuracy_issues = set(str(issue) for issue in pred_json.get('Accuracy Issues', []))\n",
 "            gold_accuracy_issues = set(str(issue) for issue in gold_json.get('Accuracy Issues', []))\n",
 "            pred_readability_issues = set(str(issue) for issue in pred_json.get('Readability Issues', []))\n",
 "            gold_readability_issues = set(str(issue) for issue in gold_json.get('Readability Issues', []))\n",
 "\n",
 "            # Jaccard similarity for each issue list\n",
 "            accuracy_issues_sim = len(pred_accuracy_issues & gold_accuracy_issues) / max(1, len(pred_accuracy_issues | gold_accuracy_issues))\n",
 "            readability_issues_sim = len(pred_readability_issues & gold_readability_issues) / max(1, len(pred_readability_issues | gold_readability_issues))\n",
 "\n",
 "            # Combine scores (0.6 weight to scores, 0.4 to issue similarity);\n",
 "            # the /8 normalizes the two score diffs, assuming each score is on a 1-5 scale\n",
 "            score_component = 1 - ((accuracy_diff + readability_diff) / 8)\n",
 "            issues_component = (accuracy_issues_sim + readability_issues_sim) / 2\n",
 "\n",
 "            final_score = 0.6 * score_component + 0.4 * issues_component\n",
 "            return max(0, final_score)\n",
 "        except Exception:\n",
 "            return 0\n",
 "\n",
 "def prepare_dataset(file_path):\n",
 "    with open(file_path, 'r') as f:\n",
 "        data = json.load(f)\n",
 "\n",
 "    prepared_data = []\n",
 "    for item in data:\n",
 "        example = dspy.Example(\n",
 "            context=f\"\"\"Previous Context:\n",
 "            Source: {item['src_prev']}\n",
 "            Translation: {item['tgt_prev']}\n",
 "\n",
 "            Next Context:\n",
 "            Source: {item['src_next']}\n",
 "            Translation: {item['tgt_next']}\"\"\",\n",
 "            source=f\"Source ({item['src_lang']}): {item['src_text']}\",\n",
 "            translation=f\"Translation ({item['tgt_lang']}): {item['main_text']}\",\n",
 "            evaluation=json.dumps(item['evaluation'], ensure_ascii=False)\n",
 "        ).with_inputs(\"context\", \"source\", \"translation\")\n",
 "        prepared_data.append(example)\n",
 "\n",
 "    # Sequential split: 70% train, 15% dev, 15% test\n",
 "    train_size = int(0.7 * len(prepared_data))\n",
 "    dev_size = int(0.15 * len(prepared_data))\n",
 "    train_data = prepared_data[:train_size]\n",
 "    dev_data = prepared_data[train_size:train_size + dev_size]\n",
 "    test_data = prepared_data[train_size + dev_size:]\n",
 "    return train_data, dev_data, test_data\n",
 "\n",
 "def partial_match_metric(example, pred, trace=None):\n",
 "    # DSPy metrics take (example, prediction, trace); gold was stored as a JSON string\n",
 "    return TranslationMetrics.partial_match_score(pred, json.loads(example.evaluation))\n",
 "\n",
 "def optimize_translation_quality_assessment():\n",
 "    # Initialize DSPy\n",
 "    lm = TranslationQualityLM()\n",
 "    dspy.settings.configure(lm=lm)\n",
 "\n",
 "    # Load and prepare dataset\n",
 "    train_data, dev_data, test_data = prepare_dataset('translation_quality_dataset.json')\n",
 "\n",
 "    # Initialize module\n",
 "    assessor = TranslationQualityAssessor()\n",
 "\n",
 "    # Initialize the MIPROv2 optimizer; exact keyword arguments vary across DSPy versions\n",
 "    optimizer = dspy.MIPROv2(\n",
 "        metric=partial_match_metric,\n",
 "        num_candidates=5,  # candidate instructions/demo sets to explore\n",
 "        init_temperature=0.7,\n",
 "        verbose=True\n",
 "    )\n",
 "\n",
 "    # Compile the module with optimization\n",
 "    compiled_assessor = optimizer.compile(\n",
 "        assessor,\n",
 "        trainset=train_data,\n",
 "        valset=dev_data\n",
 "    )\n",
 "\n",
 "    # Evaluate on the held-out test set by calling the metric functions directly\n",
 "    results = []\n",
 "    for example in test_data:\n",
 "        pred = compiled_assessor(\n",
 "            context=example.context,\n",
 "            source=example.source,\n",
 "            translation=example.translation\n",
 "        )\n",
 "        gold = json.loads(example.evaluation)\n",
 "        results.append({\n",
 "            'exact_match': TranslationMetrics.exact_match_score(pred, gold),\n",
 "            'partial_match': TranslationMetrics.partial_match_score(pred, gold)\n",
 "        })\n",
 "\n",
 "    # Calculate and print final metrics\n",
 "    avg_exact_match = np.mean([r['exact_match'] for r in results])\n",
 "    avg_partial_match = np.mean([r['partial_match'] for r in results])\n",
 "    print(f\"Average Exact Match Score: {avg_exact_match:.3f}\")\n",
 "    print(f\"Average Partial Match Score: {avg_partial_match:.3f}\")\n",
 "\n",
 "    return compiled_assessor\n",
 "\n",
 "if __name__ == \"__main__\":\n",
 "    optimized_assessor = optimize_translation_quality_assessment()" ] }
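, { "cell_type": "markdown", "id": "md-assessor-smoke", "metadata": {}, "source": [
 "A smoke-test sketch tying the DSPy assessor back to the `datapoints` built at the top of the notebook. The formatting mirrors `prepare_dataset`; it needs a configured LM and a non-empty `datapoints`, so treat it as illustrative:" ] },
 { "cell_type": "code", "execution_count": null, "id": "code-assessor-smoke", "metadata": {}, "outputs": [], "source": [
 "# Illustrative usage on one earlier datapoint (uncompiled assessor)\n",
 "if datapoints:\n",
 "    dspy.settings.configure(lm=TranslationQualityLM())\n",
 "    i = datapoints[0][\"input\"]\n",
 "    context = (\n",
 "        f\"Previous Context:\\nSource: {i['src_prev']}\\nTranslation: {i['tgt_prev']}\\n\\n\"\n",
 "        f\"Next Context:\\nSource: {i['src_next']}\\nTranslation: {i['tgt_next']}\"\n",
 "    )\n",
 "    assessor = TranslationQualityAssessor()\n",
 "    evaluation = assessor(\n",
 "        context=context,\n",
 "        source=f\"Source ({i['src_lang']}): {i['src_text']}\",\n",
 "        translation=f\"Translation ({i['tgt_lang']}): {i['tgt_text']}\"\n",
 "    )\n",
 "    print(evaluation)" ] }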