| import dspy |
| import json |
| import os |
| import random |
|
|
|
|
| |
# Reproducibility: fix the PRNG seed so any random sampling is deterministic.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)

# Student-model switch: True routes inference to the OpenAI API,
# False uses the local vLLM endpoint configured below.
USE_OPENAI_AS_STUDENT = True
OPENAI_STUDENT_MODEL = os.environ.get("OPENAI_STUDENT_MODEL", "gpt-5")

# Load API keys at import time from a local JSON file keyed by provider name.
# NOTE(review): hard-coded absolute path — module import fails if the file is
# missing or lacks an "openai" entry; consider an env var instead. TODO confirm.
api_file = "/home/mshahidul/api_new.json"
with open(api_file, "r") as f:
    api_keys = json.load(f)
openai_api_key = api_keys["openai"]
|
| |
# Local vLLM server exposed through an OpenAI-compatible endpoint.
# temperature=0.0 makes generations deterministic (greedy decoding).
vllm_model = dspy.LM(
    model="openai/dspy",
    api_base="http://172.16.34.19:4090/v1",
    api_key="EMPTY",
    temperature=0.0,
)

# Hosted OpenAI student model; model name comes from OPENAI_STUDENT_MODEL.
openai_model_student = dspy.LM(
    model=OPENAI_STUDENT_MODEL,
    api_key=openai_api_key,
)

# Select the student LM and register it as dspy's default for all modules.
student_lm = openai_model_student if USE_OPENAI_AS_STUDENT else vllm_model
dspy.configure(lm=student_lm)

# Announce which backend will serve inference for this run.
student_name = f"OpenAI ({OPENAI_STUDENT_MODEL})" if USE_OPENAI_AS_STUDENT else "vLLM (local)"
print(f"Student model (inference): {student_name}")
|
|
|
|
| |
# The three official health-literacy labels. Gold labels are validated
# against this list and model outputs are normalized to these exact strings.
LITERACY_LABELS = [
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
]
|
|
|
|
class HealthLiteracySignature(dspy.Signature):
    # NOTE: in dspy, this docstring is the instruction prompt sent to the LM —
    # editing its wording changes model behavior, so it is left verbatim.
    """
    Analyze the linguistic complexity, use of medical jargon, and sentence
    structure of 'generated_text' to determine the health literacy level.
    Output exactly one of the three labels: low_health_literacy, intermediate_health_literacy, proficient_health_literacy.
    """

    # Input: the rewritten text whose reading level should be classified.
    generated_text = dspy.InputField(
        desc="A version of the source text rewritten for a specific audience."
    )

    # Output: expected to be one of the three LITERACY_LABELS strings
    # (downstream code still normalizes wordy answers defensively).
    literacy_label = dspy.OutputField(
        desc=(
            "Exactly one of: low_health_literacy (simple words, no jargon), "
            "intermediate_health_literacy (moderate technicality), "
            "proficient_health_literacy (highly technical/original level)."
        )
    )
|
|
|
|
class HealthLiteracyClassifier(dspy.Module):
    """Thin dspy module wrapping chain-of-thought health-literacy classification."""

    def __init__(self):
        super().__init__()
        # ChainOfThought makes the LM produce reasoning before literacy_label.
        self.classifier = dspy.ChainOfThought(HealthLiteracySignature)

    def forward(self, generated_text):
        """Classify one text; the result exposes a `literacy_label` attribute."""
        return self.classifier(generated_text=generated_text)
|
|
|
|
| def _normalize_pred_to_label(pred_label: str) -> str: |
| """Extract the first matching official label from model output (handles wordy answers).""" |
| pred_label = (pred_label or "").strip().lower() |
| for label in LITERACY_LABELS: |
| if label in pred_label: |
| return label |
| return pred_label |
|
|
|
|
| |
# I/O locations: the full input test set and the filtered 200-example output.
# NOTE(review): hard-coded absolute path — portable only on the original host.
BN_DIR = "/home/mshahidul/readctrl/code/text_classifier/bn"
DATA_PATH = os.path.join(BN_DIR, "testing_bn_full.json")
OUTPUT_PATH = os.path.join(BN_DIR, "testing_bn_clean_200.json")
|
|
|
|
def main():
    """Classify every dataset item and save up to 200 curated examples.

    Reads DATA_PATH (a JSON list of instances), runs the health-literacy
    classifier on each instance with a valid gold label, splits results into
    "clean" (prediction matches gold) and "difficult" (mismatch/unclear), and
    writes up to 200 records to OUTPUT_PATH — clean examples first, topped up
    with difficult ones if fewer than 200 are clean.
    """
    classifier = HealthLiteracyClassifier()

    with open(DATA_PATH, "r", encoding="utf-8") as f:
        raw_data = json.load(f)

    print(f"Total input instances: {len(raw_data)}")

    clean_examples = []
    difficult_examples = []

    # NOTE: original iterated with enumerate() but never used the index.
    for item in raw_data:
        label = item.get("label")
        if label not in LITERACY_LABELS:
            # Skip instances without an official gold label.
            continue

        # Prefer the generated text; fall back to the alternate field.
        text = item.get("gen_text") or item.get("diff_label_texts", "")
        if not text:
            continue

        pred = classifier(generated_text=text)
        gold_label = str(label).strip().lower()
        pred_raw = str(getattr(pred, "literacy_label", "") or "").strip().lower()
        pred_normalized = _normalize_pred_to_label(pred_raw)

        # Correct when the normalized prediction matches the gold label, or
        # the gold label appears verbatim anywhere in the raw model output.
        correct = bool(gold_label == pred_normalized or gold_label in pred_raw)

        # Shallow copy so annotations don't mutate the loaded input list.
        record = dict(item)
        record["predicted_label"] = pred_normalized or pred_raw or "(empty)"
        record["prediction_correct"] = correct

        if correct:
            clean_examples.append(record)
        else:
            difficult_examples.append(record)

    print(f"Correctly predicted (easy) examples: {len(clean_examples)}")
    print(f"Difficult examples (mismatch / unclear): {len(difficult_examples)}")

    # Assemble the output set: clean examples first, then top up with
    # difficult ones until the target size (or the data runs out).
    target_n = 200
    clean_200 = list(clean_examples[:target_n])
    if len(clean_200) < target_n and difficult_examples:
        remaining = target_n - len(clean_200)
        clean_200.extend(difficult_examples[:remaining])

    print(
        f"Saving {len(clean_200)} examples to: {OUTPUT_PATH} "
        f"({sum(1 for r in clean_200 if r.get('prediction_correct'))} clean, "
        f"{sum(1 for r in clean_200 if not r.get('prediction_correct'))} difficult)"
    )

    # ensure_ascii=False keeps Bengali (non-ASCII) text readable in the file.
    with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
        json.dump(clean_200, f, ensure_ascii=False, indent=2)
|
|
|
|
# Script entry point: run the full classify-and-filter pipeline.
if __name__ == "__main__":
    main()
|
|
|
|