File size: 5,550 Bytes
c7a6fe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f068d454",
   "metadata": {},
   "outputs": [],
   "source": [
    "import dspy\n",
    "import json\n",
    "from typing import Literal\n",
    "from dspy.teleprompt import BootstrapFewShotWithRandomSearch\n",
    "from dspy.evaluate import Evaluate\n",
    "\n",
    "# --- 1. LLM Configuration ---\n",
    "api_file = \"/home/mshahidul/api_new.json\"\n",
    "with open(api_file, \"r\") as f:\n",
    "    api_keys = json.load(f)\n",
    "openai_api_key = api_keys[\"openai\"]\n",
    "\n",
    "# Student: Local vLLM (Deployment Model)\n",
    "vllm_model = dspy.LM(\n",
    "    model='openai/Qwen/Qwen3-30B-A3B-Instruct-2507',\n",
    "    api_base=\"http://172.16.34.29:8004/v1\",\n",
    "    api_key=\"EMPTY\",\n",
    "    temperature=0.0\n",
    ")\n",
    "\n",
    "# Teacher: OpenAI (High-quality rationale generation)\n",
    "openai_model = dspy.LM(model='gpt-5', api_key=openai_api_key, temperature=0.0)\n",
    "\n",
    "dspy.configure(lm=openai_model)  # Default to OpenAI for optimization\n",
    "\n",
    "# --- 2. Data Processing & Deduplication ---\n",
    "\n",
    "# 2.1 Load Training Data (Few-Shot)\n",
    "with open(\"/home/mshahidul/readctrl/data/new_exp/few_shot_examples.json\", 'r') as f:\n",
    "    few_shot_data = json.load(f)\n",
    "\n",
    "trainset = []\n",
    "train_identifiers = set()\n",
    "\n",
    "for label_key, examples in few_shot_data.items():\n",
    "    for ex in examples:\n",
    "        # Create a unique ID to prevent data leakage\n",
    "        unique_id = f\"{ex['doc_id']}_{label_key}\"\n",
    "        train_identifiers.add(unique_id)\n",
    "        \n",
    "        # In few_shot, 'gen_text' is the summary we want to judge\n",
    "        trainset.append(dspy.Example(\n",
    "            summary_text=ex['gen_text'], \n",
    "            label=label_key\n",
    "        ).with_inputs('summary_text'))\n",
    "\n",
    "# 2.2 Load Dev Data (Filtered)\n",
    "with open(\"/home/mshahidul/readctrl/data/new_exp/cleaned_health_literacy_data.json\", 'r') as f:\n",
    "    main_data = json.load(f)\n",
    "\n",
    "devset = []\n",
    "for item in main_data:\n",
    "    unique_id = f\"{item['doc_id']}_{item['label']}\"\n",
    "    \n",
    "    # Only add to devset if it wasn't used in training\n",
    "    if unique_id not in train_identifiers:\n",
    "        # Based on your update: 'gen_text' or 'text' is the generated summary\n",
    "        # We use 'gen_text' here as the summary to be judged\n",
    "        devset.append(dspy.Example(\n",
    "            summary_text=item['gen_text'], \n",
    "            label=item['label']\n",
    "        ).with_inputs('summary_text'))\n",
    "\n",
    "# Optionally cap the devset for efficiency during optimization,\n",
    "# e.g. devset = devset[:200]. Currently the full filtered devset is used.\n",
    "\n",
    "print(f\"Dataset Stats: Train={len(trainset)}, Dev={len(devset)}\")\n",
    "\n",
    "# --- 3. Robust Signature & Module ---\n",
    "\n",
    "class HealthLiteracySignature(dspy.Signature):\n",
    "    \"\"\"\n",
    "    Judge the health literacy level of a generated medical summary.\n",
    "    Identify if the language is suitable for a layperson (low) or requires medical expertise (proficient).\n",
    "    \"\"\"\n",
    "    summary_text: str = dspy.InputField(desc=\"The generated medical summary to be analyzed.\")\n",
    "    reasoning: str = dspy.OutputField(desc=\"Analysis of jargon, acronyms, and sentence complexity.\")\n",
    "    label: Literal[\"low_health_literacy\", \"intermediate_health_literacy\", \"proficient_health_literacy\"] = dspy.OutputField()\n",
    "\n",
    "class HealthLiteracyClassifier(dspy.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # ChainOfThought generates the reasoning field before the label\n",
    "        self.predictor = dspy.ChainOfThought(HealthLiteracySignature)\n",
    "\n",
    "    def forward(self, summary_text):\n",
    "        return self.predictor(summary_text=summary_text)\n",
    "\n",
    "# --- 4. Metric and Optimization ---\n",
    "\n",
    "def health_literacy_metric(gold, pred, trace=None):\n",
    "    if not pred.label: return False\n",
    "    return gold.label.strip().lower() == pred.label.strip().lower()\n",
    "\n",
    "# BootstrapFewShotWithRandomSearch explores different demonstration combinations\n",
    "optimizer = BootstrapFewShotWithRandomSearch(\n",
    "    metric=health_literacy_metric,\n",
    "    max_bootstrapped_demos=3,\n",
    "    num_candidate_programs=8, \n",
    "    teacher_settings=dict(lm=openai_model)\n",
    ")\n",
    "\n",
    "# NOTE(review): dspy.configure(lm=openai_model) above means BOTH teacher and\n",
    "# student run on OpenAI here; vllm_model is defined but never used. To bootstrap\n",
    "# demos for the local deployment model, call dspy.configure(lm=vllm_model)\n",
    "# before compiling.\n",
    "optimized_program = optimizer.compile(HealthLiteracyClassifier(), trainset=trainset)\n",
    "\n",
    "# --- 5. Evaluation & Saving ---\n",
    "\n",
    "evaluator = Evaluate(devset=devset, metric=health_literacy_metric, num_threads=1, display_progress=True)\n",
    "accuracy = evaluator(optimized_program)\n",
    "\n",
    "print(f\"\\nOptimization Complete.\")\n",
    "# Assumes Evaluate returns an EvaluationResult whose .score is the accuracy\n",
    "# percentage (recent DSPy); older versions returned a bare float -- verify.\n",
    "print(f\"Final Accuracy on Unseen Dev Set: {accuracy.score}%\")\n",
    "\n",
    "# Save the finalized prompt logic\n",
    "optimized_program.save(\"/home/mshahidul/readctrl/data/new_exp/optimized_health_classifier.json\")"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}