{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "f068d454",
"metadata": {},
"outputs": [],
"source": [
"import dspy\n",
"import json\n",
"from typing import Literal\n",
"from dspy.teleprompt import BootstrapFewShotWithRandomSearch\n",
"from dspy.evaluate import Evaluate\n",
"\n",
"# --- 1. LLM Configuration ---\n",
"api_file = \"/home/mshahidul/api_new.json\"\n",
"with open(api_file, \"r\") as f:\n",
" api_keys = json.load(f)\n",
"openai_api_key = api_keys[\"openai\"]\n",
"\n",
"# Student: Local vLLM (Deployment Model)\n",
"vllm_model = dspy.LM(\n",
" model='openai/Qwen/Qwen3-30B-A3B-Instruct-2507',\n",
" api_base=\"http://172.16.34.29:8004/v1\",\n",
" api_key=\"EMPTY\",\n",
" temperature=0.0\n",
")\n",
"\n",
"# Teacher: OpenAI (High-quality rationale generation)\n",
"openai_model = dspy.LM(model='gpt-5', api_key=openai_api_key, temperature=0.0)  # NOTE(review): some gpt-5 endpoints reject temperature != 1 — confirm\n",
"\n",
"dspy.configure(lm=vllm_model)  # Student/default LM is the local deployment model; OpenAI acts as teacher via teacher_settings below\n",
"\n",
"# --- 2. Data Processing & Deduplication ---\n",
"\n",
"# 2.1 Load Training Data (Few-Shot)\n",
"with open(\"/home/mshahidul/readctrl/data/new_exp/few_shot_examples.json\", 'r') as f:\n",
" few_shot_data = json.load(f)\n",
"\n",
"trainset = []\n",
"train_identifiers = set()\n",
"\n",
"for label_key, examples in few_shot_data.items():\n",
" for ex in examples:\n",
" # Create a unique ID to prevent data leakage\n",
" unique_id = f\"{ex['doc_id']}_{label_key}\"\n",
" train_identifiers.add(unique_id)\n",
" \n",
" # In few_shot, 'text' is the summary we want to judge\n",
" trainset.append(dspy.Example(\n",
" summary_text=ex['gen_text'], \n",
" label=label_key\n",
" ).with_inputs('summary_text'))\n",
"\n",
"# 2.2 Load Dev Data (Filtered)\n",
"with open(\"/home/mshahidul/readctrl/data/new_exp/cleaned_health_literacy_data.json\", 'r') as f:\n",
" main_data = json.load(f)\n",
"\n",
"devset = []\n",
"for item in main_data:\n",
" unique_id = f\"{item['doc_id']}_{item['label']}\"\n",
" \n",
" # Only add to devset if it wasn't used in training\n",
" if unique_id not in train_identifiers:\n",
" # Based on your update: 'gen_text' or 'text' is the generated summary\n",
" # We use 'gen_text' here as the summary to be judged\n",
" devset.append(dspy.Example(\n",
" summary_text=item['gen_text'], \n",
" label=item['label']\n",
" ).with_inputs('summary_text'))\n",
"\n",
"# NOTE(review): no cap is applied — the full devset is evaluated; slice here (e.g. devset[:200]) if optimization runtime is a concern\n",
"\n",
"print(f\"Dataset Stats: Train={len(trainset)}, Dev={len(devset)}\")\n",
"\n",
"# --- 3. Robust Signature & Module ---\n",
"\n",
"class HealthLiteracySignature(dspy.Signature):\n",
" \"\"\"\n",
" Judge the health literacy level of a generated medical summary.\n",
" Identify if the language is suitable for a layperson (low) or requires medical expertise (proficient).\n",
" \"\"\"\n",
" summary_text: str = dspy.InputField(desc=\"The generated medical summary to be analyzed.\")\n",
" reasoning: str = dspy.OutputField(desc=\"Analysis of jargon, acronyms, and sentence complexity.\")\n",
" label: Literal[\"low_health_literacy\", \"intermediate_health_literacy\", \"proficient_health_literacy\"] = dspy.OutputField()\n",
"\n",
"class HealthLiteracyClassifier(dspy.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" # ChainOfThought generates the reasoning field before the label\n",
" self.predictor = dspy.ChainOfThought(HealthLiteracySignature)\n",
"\n",
" def forward(self, summary_text):\n",
" return self.predictor(summary_text=summary_text)\n",
"\n",
"# --- 4. Metric and Optimization ---\n",
"\n",
"def health_literacy_metric(gold, pred, trace=None):\n",
" if not pred.label: return False\n",
" return gold.label.strip().lower() == pred.label.strip().lower()\n",
"\n",
"# BootstrapFewShotWithRandomSearch explores different demonstration combinations\n",
"optimizer = BootstrapFewShotWithRandomSearch(\n",
" metric=health_literacy_metric,\n",
" max_bootstrapped_demos=3,\n",
" num_candidate_programs=8, \n",
" teacher_settings=dict(lm=openai_model)\n",
")\n",
"\n",
"# Compile using the local model, but with OpenAI generating the logic\n",
"optimized_program = optimizer.compile(HealthLiteracyClassifier(), trainset=trainset)\n",
"\n",
"# --- 5. Evaluation & Saving ---\n",
"\n",
"evaluator = Evaluate(devset=devset, metric=health_literacy_metric, num_threads=1, display_progress=True)\n",
"accuracy = evaluator(optimized_program)\n",
"\n",
"print(\"\\nOptimization Complete.\")\n",
"# Evaluate returns an EvaluationResult in DSPy >= 2.6; .score is already a percentage\n",
"print(f\"Final Accuracy on Unseen Dev Set: {accuracy.score}%\")\n",
"\n",
"# Save the finalized prompt logic\n",
"optimized_program.save(\"/home/mshahidul/readctrl/data/new_exp/optimized_health_classifier.json\")"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}