{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "db631ce7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initial Classifier Prompt (p0)\n",
    "target_trainable_instruction = \"\"\"Identify the health literacy level of the following medical text. \n",
    "Select exactly one label from: [low_health_literacy, intermediate_health_literacy, proficient_health_literacy].\n",
    "Think about the medical terminology used, sentence complexity, and clarity for a general audience.\"\"\"\n",
    "\n",
    "# The specific classification instruction format\n",
    "classify_raw_instruction = \"\"\"[target_trainable_instruction]\n",
    "[target_trainable_few_shot_examples]\n",
    "\n",
    "Medical Text:\n",
    "[gen_text]\n",
    "\n",
    "Output your classification.\n",
    "Return the output as a JSON object: {\"prediction\": \"label_here\"}\n",
    "\"\"\"\n",
    "\n",
    "# The \"Gradient\" Prompt (Forward Step)\n",
    "# This explains why the model misclassified a sample and suggests an instruction update.\n",
    "training_prompt_forward = \"\"\"In this task, you are an expert linguist. \n",
    "We are using an AI to classify the health literacy level of medical text, but it is making mistakes.\n",
    "Your job is to analyze the error and suggest how to modify the instruction to fix it.\n",
    "\n",
    "Current Instruction:\n",
    "[target_trainable_instruction]\n",
    "\n",
    "Medical Text:\n",
    "[gen_text]\n",
    "\n",
    "AI Predicted Label: [AI_prediction]\n",
    "Correct Ground Truth Label: [label_summary]\n",
    "\n",
    "Requirements for your suggestions:\n",
    "1) Suggest high-level linguistic criteria (e.g., focus on syllable count, jargon, or tone).\n",
    "2) Do not include specific examples.\n",
    "3) Focus only on improving classification accuracy.\n",
    "\n",
    "Return the output as a JSON: {\"reasons\": \"...\", \"suggestions\": \"...\"}\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f3316de5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "import openai  # was missing: openai.ChatCompletion is called below\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "def do_classify(target_trainable_instruction, classify_raw_instruction, gen_text,\n",
    "                target_trainable_few_shot_examples='', do_few_shot=False):\n",
    "    \"\"\"Classify one medical text with the current prompt version.\n",
    "\n",
    "    Returns the predicted label string parsed from the model's JSON reply,\n",
    "    or the sentinel 'error' when the reply cannot be parsed.\n",
    "    \"\"\"\n",
    "    # Construct the prompt by filling in the placeholders\n",
    "    instruction = classify_raw_instruction.replace('[target_trainable_instruction]', target_trainable_instruction)\n",
    "    instruction = instruction.replace('[gen_text]', gen_text)\n",
    "\n",
    "    if do_few_shot:\n",
    "        instruction = instruction.replace('[target_trainable_few_shot_examples]', target_trainable_few_shot_examples)\n",
    "    else:\n",
    "        instruction = instruction.replace('[target_trainable_few_shot_examples]', '')\n",
    "\n",
    "    # Call OpenAI (or your local vLLM) -- legacy (openai<1.0) ChatCompletion API\n",
    "    response = openai.ChatCompletion.create(\n",
    "        model=\"gpt-5\",\n",
    "        messages=[{\"role\": \"system\", \"content\": instruction}],\n",
    "    )\n",
    "\n",
    "    try:\n",
    "        content = response[\"choices\"][0][\"message\"][\"content\"]\n",
    "        return json.loads(content, strict=False)['prediction']\n",
    "    except (KeyError, TypeError, ValueError):\n",
    "        # Missing keys, non-dict response, or non-JSON content\n",
    "        # (json.JSONDecodeError is a ValueError). Report a sentinel\n",
    "        # prediction instead of crashing the whole evaluation loop.\n",
    "        return \"error\"\n",
    "\n",
    "\n",
    "def training_forward_step(training_prompt_forward, target_trainable_instruction,\n",
    "                          gen_text, AI_prediction, label_summary):\n",
    "    \"\"\"'Gradient' step: ask a strong model to explain a misclassification.\n",
    "\n",
    "    Returns the 'suggestions' string from the model's JSON reply.\n",
    "    \"\"\"\n",
    "    # Replaces placeholders with the classification error details\n",
    "    instruction = training_prompt_forward.replace('[target_trainable_instruction]', target_trainable_instruction)\n",
    "    instruction = instruction.replace('[gen_text]', gen_text)\n",
    "    instruction = instruction.replace('[AI_prediction]', AI_prediction)\n",
    "    instruction = instruction.replace('[label_summary]', label_summary)\n",
    "\n",
    "    response = openai.ChatCompletion.create(\n",
    "        model=\"gpt-4\",  # High reasoning model recommended for the \"gradient\" step\n",
    "        messages=[{\"role\": \"system\", \"content\": instruction}],\n",
    "        temperature=0\n",
    "    )\n",
    "    return json.loads(response[\"choices\"][0][\"message\"][\"content\"], strict=False)['suggestions']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c3aeae14",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Load Test Set\n",
    "with open('/home/mshahidul/readctrl/data/new_exp/test_health_literacy_data.json', 'r') as f:\n",
    "    test_data = json.load(f)\n",
    "eval_df = pd.DataFrame(test_data)\n",
    "\n",
    "# Load Few-shot Data (For the training pool)\n",
    "with open('/home/mshahidul/readctrl/data/new_exp/few_shot_examples.json', 'r') as f:\n",
    "    few_shot_json = json.load(f)\n",
    "\n",
    "# Flatten the per-category dict into one training pool\n",
    "all_train_records = []\n",
    "for category, records in few_shot_json.items():\n",
    "    for record in records:\n",
    "        # Ensure the 'label' matches the category key for training\n",
    "        record['label_actual'] = category\n",
    "        all_train_records.append(record)\n",
    "train_df = pd.DataFrame(all_train_records)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6ed53650",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score, classification_report, f1_score\n",
    "\n",
    "\n",
    "class ClassificationEval:\n",
    "    \"\"\"Accuracy / macro-F1 evaluation for the health-literacy classifier.\"\"\"\n",
    "\n",
    "    def __init__(self, labels=('low_health_literacy', 'intermediate_health_literacy',\n",
    "                               'proficient_health_literacy')):\n",
    "        # Tuple default avoids the shared-mutable-default pitfall;\n",
    "        # copy to a list so callers can't mutate our state from outside.\n",
    "        self.target_names = list(labels)\n",
    "\n",
    "    def run_evaluation(self, labels, preds):\n",
    "        \"\"\"\n",
    "        Calculates accuracy and macro F1 for the classification task,\n",
    "        ignoring predictions that are not one of the known labels.\n",
    "        \"\"\"\n",
    "        # Filter out errors or invalid labels to prevent a metrics crash\n",
    "        valid_indices = [i for i, p in enumerate(preds) if p in self.target_names]\n",
    "\n",
    "        filtered_labels = [labels[i] for i in valid_indices]\n",
    "        filtered_preds = [preds[i] for i in valid_indices]\n",
    "\n",
    "        if not filtered_preds:\n",
    "            # No valid prediction at all -- sklearn raises on empty input,\n",
    "            # so report zeroed metrics instead.\n",
    "            return {\"accuracy\": 0.0, \"f1_macro\": 0.0,\n",
    "                    \"valid_count\": 0, \"total_count\": len(preds)}\n",
    "\n",
    "        results = {\n",
    "            \"accuracy\": accuracy_score(filtered_labels, filtered_preds),\n",
    "            \"f1_macro\": f1_score(filtered_labels, filtered_preds, average='macro'),\n",
    "            \"valid_count\": len(filtered_preds),\n",
    "            \"total_count\": len(preds)\n",
    "        }\n",
    "\n",
    "        return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "71c544b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_loop(eval_df, target_trainable_instruction, classify_raw_instruction,\n",
    "              target_trainable_few_shot_examples, do_few_shot, classifier_eval):\n",
    "    \"\"\"Run the classifier over eval_df and return a metrics dict.\n",
    "\n",
    "    The returned dict contains rounded metrics plus the raw 'labels' and\n",
    "    'preds' lists for later inspection.\n",
    "    \"\"\"\n",
    "    preds = []\n",
    "    labels = []\n",
    "\n",
    "    for i in tqdm(range(eval_df.shape[0]), desc=\"Evaluating Readability\"):\n",
    "        row = eval_df.iloc[i]\n",
    "        gen_text = row['gen_text']  # The medical text to classify\n",
    "        ground_truth = row['label']  # The actual literacy level\n",
    "\n",
    "        try:\n",
    "            # Predict using the current prompt version\n",
    "            prediction = do_classify(\n",
    "                target_trainable_instruction,\n",
    "                classify_raw_instruction,\n",
    "                gen_text,\n",
    "                target_trainable_few_shot_examples,\n",
    "                do_few_shot\n",
    "            )\n",
    "            preds.append(prediction)\n",
    "            labels.append(ground_truth)\n",
    "        except Exception as e:\n",
    "            print(f\"Error at row {i}: {e}\")\n",
    "            continue\n",
    "\n",
    "    # Calculate classification metrics\n",
    "    metrics = classifier_eval.run_evaluation(labels, preds)\n",
    "\n",
    "    # Round floats for compact logging\n",
    "    eval_dict = {k: round(v, 4) if isinstance(v, float) else v for k, v in metrics.items()}\n",
    "    eval_dict['labels'] = labels\n",
    "    eval_dict['preds'] = preds\n",
    "\n",
    "    return eval_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa91a214",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "un",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}