{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "f068d454", "metadata": {}, "outputs": [], "source": [ "import dspy\n", "import json\n", "from typing import Literal\n", "from dspy.teleprompt import BootstrapFewShotWithRandomSearch\n", "from dspy.evaluate import Evaluate\n", "\n", "# --- 1. LLM Configuration ---\n", "api_file = \"/home/mshahidul/api_new.json\"\n", "with open(api_file, \"r\") as f:\n", " api_keys = json.load(f)\n", "openai_api_key = api_keys[\"openai\"]\n", "\n", "# Student: Local vLLM (Deployment Model)\n", "vllm_model = dspy.LM(\n", " model='openai/Qwen/Qwen3-30B-A3B-Instruct-2507',\n", " api_base=\"http://172.16.34.29:8004/v1\",\n", " api_key=\"EMPTY\",\n", " temperature=0.0\n", ")\n", "\n", "# Teacher: OpenAI (High-quality rationale generation)\n", "openai_model = dspy.LM(model='gpt-5', api_key=openai_api_key, temperature=0.0)\n", "\n", "dspy.configure(lm=openai_model) # Default to OpenAI for optimization\n", "\n", "# --- 2. Data Processing & Deduplication ---\n", "\n", "# 2.1 Load Training Data (Few-Shot)\n", "with open(\"/home/mshahidul/readctrl/data/new_exp/few_shot_examples.json\", 'r') as f:\n", " few_shot_data = json.load(f)\n", "\n", "trainset = []\n", "train_identifiers = set()\n", "\n", "for label_key, examples in few_shot_data.items():\n", " for ex in examples:\n", " # Create a unique ID to prevent data leakage\n", " unique_id = f\"{ex['doc_id']}_{label_key}\"\n", " train_identifiers.add(unique_id)\n", " \n", " # In few_shot, 'text' is the summary we want to judge\n", " trainset.append(dspy.Example(\n", " summary_text=ex['gen_text'], \n", " label=label_key\n", " ).with_inputs('summary_text'))\n", "\n", "# 2.2 Load Dev Data (Filtered)\n", "with open(\"/home/mshahidul/readctrl/data/new_exp/cleaned_health_literacy_data.json\", 'r') as f:\n", " main_data = json.load(f)\n", "\n", "devset = []\n", "for item in main_data:\n", " unique_id = f\"{item['doc_id']}_{item['label']}\"\n", " \n", " # Only add to devset if it wasn't used in training\n", " if unique_id not in train_identifiers:\n", " # Based on your update: 'gen_text' or 'text' is the generated summary\n", " # We use 'gen_text' here as the summary to be judged\n", " devset.append(dspy.Example(\n", " summary_text=item['gen_text'], \n", " label=item['label']\n", " ).with_inputs('summary_text'))\n", "\n", "# Cap devset for efficiency during optimization\n", "devset = devset\n", "\n", "print(f\"Dataset Stats: Train={len(trainset)}, Dev={len(devset)}\")\n", "\n", "# --- 3. Robust Signature & Module ---\n", "\n", "class HealthLiteracySignature(dspy.Signature):\n", " \"\"\"\n", " Judge the health literacy level of a generated medical summary.\n", " Identify if the language is suitable for a layperson (low) or requires medical expertise (proficient).\n", " \"\"\"\n", " summary_text: str = dspy.InputField(desc=\"The generated medical summary to be analyzed.\")\n", " reasoning: str = dspy.OutputField(desc=\"Analysis of jargon, acronyms, and sentence complexity.\")\n", " label: Literal[\"low_health_literacy\", \"intermediate_health_literacy\", \"proficient_health_literacy\"] = dspy.OutputField()\n", "\n", "class HealthLiteracyClassifier(dspy.Module):\n", " def __init__(self):\n", " super().__init__()\n", " # ChainOfThought generates the reasoning field before the label\n", " self.predictor = dspy.ChainOfThought(HealthLiteracySignature)\n", "\n", " def forward(self, summary_text):\n", " return self.predictor(summary_text=summary_text)\n", "\n", "# --- 4. Metric and Optimization ---\n", "\n", "def health_literacy_metric(gold, pred, trace=None):\n", " if not pred.label: return False\n", " return gold.label.strip().lower() == pred.label.strip().lower()\n", "\n", "# BootstrapFewShotWithRandomSearch explores different demonstration combinations\n", "optimizer = BootstrapFewShotWithRandomSearch(\n", " metric=health_literacy_metric,\n", " max_bootstrapped_demos=3,\n", " num_candidate_programs=8, \n", " teacher_settings=dict(lm=openai_model)\n", ")\n", "\n", "# Compile using the local model, but with OpenAI generating the logic\n", "optimized_program = optimizer.compile(HealthLiteracyClassifier(), trainset=trainset)\n", "\n", "# --- 5. Evaluation & Saving ---\n", "\n", "evaluator = Evaluate(devset=devset, metric=health_literacy_metric, num_threads=1, display_progress=True)\n", "accuracy = evaluator(optimized_program)\n", "\n", "print(f\"\\nOptimization Complete.\")\n", "print(f\"Final Accuracy on Unseen Dev Set: {accuracy.score}%\")\n", "# print(f\"Final Accuracy on Unseen Dev Set: {accuracy * 100:.2f}%\")\n", "\n", "# Save the finalized prompt logic\n", "optimized_program.save(\"/home/mshahidul/readctrl/data/new_exp/optimized_health_classifier.json\")" ] } ], "metadata": { "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 5 }