# File size: 4,953 Bytes
# Standard library first, third-party second (PEP 8 import grouping).
import json
import os
import random

import dspy

# Reproducibility: fix the PRNG seed so any random sampling is repeatable.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
# --- LLM Configuration (student only for inference) ---
# Student backend switch: True -> OpenAI API; False -> local vLLM server.
USE_OPENAI_AS_STUDENT = True
OPENAI_STUDENT_MODEL = os.environ.get("OPENAI_STUDENT_MODEL", "gpt-5")

# API keys are kept outside the repo; expected schema: {"openai": "<key>", ...}.
api_file = "/home/mshahidul/api_new.json"
with open(api_file, "r", encoding="utf-8") as f:
    api_keys = json.load(f)
openai_api_key = api_keys["openai"]

# Student: Local vLLM (Deployment Model) — OpenAI-compatible endpoint.
vllm_model = dspy.LM(
    model="openai/dspy",
    api_base="http://172.16.34.19:4090/v1",
    api_key="EMPTY",  # vLLM ignores the key, but the client requires one
    temperature=0.0,
)

# Student: OpenAI (optional)
openai_model_student = dspy.LM(
    model=OPENAI_STUDENT_MODEL,
    api_key=openai_api_key,
)

# Route all dspy modules to the selected student LM.
student_lm = openai_model_student if USE_OPENAI_AS_STUDENT else vllm_model
dspy.configure(lm=student_lm)

student_name = f"OpenAI ({OPENAI_STUDENT_MODEL})" if USE_OPENAI_AS_STUDENT else "vLLM (local)"
print(f"Student model (inference): {student_name}")
# --- Labels, signature, and helpers (mirrors training script) ---
# Canonical health-literacy class labels. The model's output must contain
# exactly one of these; used for gold-label validation and prediction
# normalization below.
LITERACY_LABELS = [
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
]
class HealthLiteracySignature(dspy.Signature):
    """
    Analyze the linguistic complexity, use of medical jargon, and sentence
    structure of 'generated_text' to determine the health literacy level.
    Output exactly one of the three labels: low_health_literacy, intermediate_health_literacy, proficient_health_literacy.
    """

    # NOTE: the docstring above and the `desc` strings below are part of the
    # LLM prompt — edit them only to change model behavior, never for style.
    generated_text = dspy.InputField(
        desc="A version of the source text rewritten for a specific audience."
    )
    literacy_label = dspy.OutputField(
        desc=(
            "Exactly one of: low_health_literacy (simple words, no jargon), "
            "intermediate_health_literacy (moderate technicality), "
            "proficient_health_literacy (highly technical/original level)."
        )
    )
class HealthLiteracyClassifier(dspy.Module):
    """Chain-of-thought classifier mapping a text to a health-literacy label."""

    def __init__(self):
        super().__init__()
        # ChainOfThought makes the model reason before emitting literacy_label.
        self.classifier = dspy.ChainOfThought(HealthLiteracySignature)

    def forward(self, generated_text):
        """Return the dspy prediction for *generated_text* (exposes .literacy_label)."""
        return self.classifier(generated_text=generated_text)
def _normalize_pred_to_label(pred_label: str) -> str:
"""Extract the first matching official label from model output (handles wordy answers)."""
pred_label = (pred_label or "").strip().lower()
for label in LITERACY_LABELS:
if label in pred_label:
return label
return pred_label
# --- Paths ---
BN_DIR = "/home/mshahidul/readctrl/code/text_classifier/bn"
# Input: full labeled test set. Output: curated subset of up to 200 examples.
DATA_PATH = os.path.join(BN_DIR, "testing_bn_full.json")
OUTPUT_PATH = os.path.join(BN_DIR, "testing_bn_clean_200.json")
def main():
    """Classify every test example and save a curated subset of 200.

    Examples the student model classifies correctly ("clean"/easy) are
    preferred; if fewer than 200 exist, the remaining slots are filled
    with difficult (misclassified) examples. Writes the result to
    OUTPUT_PATH as pretty-printed, non-ASCII-preserving JSON.
    """
    # Initialize classifier (uses current student LM via dspy.configure above)
    classifier = HealthLiteracyClassifier()

    # Load full dataset
    with open(DATA_PATH, "r", encoding="utf-8") as f:
        raw_data = json.load(f)
    print(f"Total input instances: {len(raw_data)}")

    clean_examples = []
    difficult_examples = []
    for item in raw_data:
        label = item.get("label")
        if label not in LITERACY_LABELS:
            # Skip unknown labels
            continue
        # Prefer the generated text; fall back to the alternate-label field.
        text = item.get("gen_text") or item.get("diff_label_texts", "")
        if not text:
            continue

        pred = classifier(generated_text=text)
        gold_label = str(label).strip().lower()
        pred_raw = str(getattr(pred, "literacy_label", "") or "").strip().lower()
        pred_normalized = _normalize_pred_to_label(pred_raw)
        # Lenient correctness: normalized match, or the gold label appears
        # anywhere in the raw model output.
        correct = bool(gold_label == pred_normalized or gold_label in pred_raw)

        record = dict(item)  # shallow copy so the input data is not mutated
        record["predicted_label"] = pred_normalized or pred_raw or "(empty)"
        record["prediction_correct"] = correct
        if correct:
            clean_examples.append(record)
        else:
            difficult_examples.append(record)

    print(f"Correctly predicted (easy) examples: {len(clean_examples)}")
    print(f"Difficult examples (mismatch / unclear): {len(difficult_examples)}")

    # Target: 200 examples total.
    # Prefer clean/easy examples; if there are fewer than 200,
    # fill the remaining slots with difficult examples.
    target_n = 200
    clean_200 = list(clean_examples[:target_n])
    if len(clean_200) < target_n and difficult_examples:
        remaining = target_n - len(clean_200)
        clean_200.extend(difficult_examples[:remaining])

    print(
        f"Saving {len(clean_200)} examples to: {OUTPUT_PATH} "
        f"({sum(1 for r in clean_200 if r.get('prediction_correct'))} clean, "
        f"{sum(1 for r in clean_200 if not r.get('prediction_correct'))} difficult)"
    )
    with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
        json.dump(clean_200, f, ensure_ascii=False, indent=2)


if __name__ == "__main__":
    main()