"""Filter a health-literacy test set down to 200 examples.

Runs a DSPy chain-of-thought classifier (the "student" LM) over every
example in the input JSON and saves up to 200 of them, preferring the
examples the student already predicts correctly ("clean") and topping up
with "difficult" (mispredicted) examples if fewer than 200 are clean.
"""

import json
import os
import random

import dspy

# Reproducibility
RANDOM_SEED = 42
random.seed(RANDOM_SEED)

# --- LLM Configuration (student only for inference) ---
# Student: "openai" = OpenAI API; "vllm" = local vLLM server
USE_OPENAI_AS_STUDENT = True
OPENAI_STUDENT_MODEL = os.environ.get("OPENAI_STUDENT_MODEL", "gpt-5")

# API keys are read from a local JSON file at import time.
# NOTE(review): hard-coded absolute path — confirm it exists on the target host.
api_file = "/home/mshahidul/api_new.json"
with open(api_file, "r") as f:
    api_keys = json.load(f)
openai_api_key = api_keys["openai"]

# Student: Local vLLM (Deployment Model)
vllm_model = dspy.LM(
    model="openai/dspy",
    api_base="http://172.16.34.19:4090/v1",
    api_key="EMPTY",
    temperature=0.0,
)

# Student: OpenAI (optional)
openai_model_student = dspy.LM(
    model=OPENAI_STUDENT_MODEL,
    api_key=openai_api_key,
)

# Select the student LM and make it the global default for all DSPy modules.
student_lm = openai_model_student if USE_OPENAI_AS_STUDENT else vllm_model
dspy.configure(lm=student_lm)

student_name = f"OpenAI ({OPENAI_STUDENT_MODEL})" if USE_OPENAI_AS_STUDENT else "vLLM (local)"
print(f"Student model (inference): {student_name}")

# --- Labels, signature, and helpers (mirrors training script) ---
LITERACY_LABELS = [
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
]


class HealthLiteracySignature(dspy.Signature):
    """
    Analyze the linguistic complexity, use of medical jargon, and sentence
    structure of 'generated_text' to determine the health literacy level.
    Output exactly one of the three labels: low_health_literacy,
    intermediate_health_literacy, proficient_health_literacy.
    """

    generated_text = dspy.InputField(
        desc="A version of the source text rewritten for a specific audience."
    )
    literacy_label = dspy.OutputField(
        desc=(
            "Exactly one of: low_health_literacy (simple words, no jargon), "
            "intermediate_health_literacy (moderate technicality), "
            "proficient_health_literacy (highly technical/original level)."
        )
    )


class HealthLiteracyClassifier(dspy.Module):
    """Chain-of-thought wrapper around HealthLiteracySignature."""

    def __init__(self):
        super().__init__()
        self.classifier = dspy.ChainOfThought(HealthLiteracySignature)

    def forward(self, generated_text):
        # Returns a dspy Prediction with a 'literacy_label' attribute.
        return self.classifier(generated_text=generated_text)


def _normalize_pred_to_label(pred_label: str) -> str:
    """Extract the first matching official label from model output (handles wordy answers)."""
    pred_label = (pred_label or "").strip().lower()
    for label in LITERACY_LABELS:
        if label in pred_label:
            return label
    # No official label found: return the cleaned raw text unchanged.
    return pred_label


# --- Paths ---
BN_DIR = "/home/mshahidul/readctrl/code/text_classifier/bn"
DATA_PATH = os.path.join(BN_DIR, "testing_bn_full.json")
OUTPUT_PATH = os.path.join(BN_DIR, "testing_bn_clean_200.json")


def _split_by_prediction(classifier, raw_data):
    """Classify every usable item and split into (clean, difficult) lists.

    An item is skipped when its 'label' is not an official literacy label or
    when it has no text in 'gen_text' / 'diff_label_texts'. Each kept record
    is a copy of the item annotated with 'predicted_label' and
    'prediction_correct'.

    Returns:
        (clean_examples, difficult_examples): records the student predicted
        correctly vs. incorrectly/unclearly.
    """
    clean_examples = []
    difficult_examples = []
    for item in raw_data:
        label = item.get("label")
        if label not in LITERACY_LABELS:
            # Skip unknown labels
            continue

        text = item.get("gen_text") or item.get("diff_label_texts", "")
        if not text:
            continue

        pred = classifier(generated_text=text)
        gold_label = str(label).strip().lower()
        pred_raw = str(getattr(pred, "literacy_label", "") or "").strip().lower()
        pred_normalized = _normalize_pred_to_label(pred_raw)

        # Count as correct on an exact normalized match OR when the gold
        # label string appears anywhere in the raw (possibly wordy) answer.
        correct = bool(gold_label == pred_normalized or gold_label in pred_raw)

        record = dict(item)
        record["predicted_label"] = pred_normalized or pred_raw or "(empty)"
        record["prediction_correct"] = correct
        if correct:
            clean_examples.append(record)
        else:
            difficult_examples.append(record)
    return clean_examples, difficult_examples


def main():
    """Classify the full dataset and save up to 200 examples to OUTPUT_PATH."""
    # Initialize classifier (uses current student LM via dspy.configure above)
    classifier = HealthLiteracyClassifier()

    # Load full dataset
    with open(DATA_PATH, "r", encoding="utf-8") as f:
        raw_data = json.load(f)
    print(f"Total input instances: {len(raw_data)}")

    clean_examples, difficult_examples = _split_by_prediction(classifier, raw_data)

    print(f"Correctly predicted (easy) examples: {len(clean_examples)}")
    print(f"Difficult examples (mismatch / unclear): {len(difficult_examples)}")

    # Target: 200 examples total.
    # Prefer clean/easy examples; if there are fewer than 200,
    # fill the remaining slots with difficult examples.
    target_n = 200
    clean_200 = list(clean_examples[:target_n])
    if len(clean_200) < target_n and difficult_examples:
        remaining = target_n - len(clean_200)
        clean_200.extend(difficult_examples[:remaining])

    print(
        f"Saving {len(clean_200)} examples to: {OUTPUT_PATH} "
        f"({sum(1 for r in clean_200 if r.get('prediction_correct'))} clean, "
        f"{sum(1 for r in clean_200 if not r.get('prediction_correct'))} difficult)"
    )
    with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
        json.dump(clean_200, f, ensure_ascii=False, indent=2)


if __name__ == "__main__":
    main()