# readCtrl_lambda/code/text_classifier/bn/inference_clean_200.py
# Author: mshahidul — "Initial commit of readCtrl code without large models" (commit 030876e)
import dspy
import json
import os
import random

# Reproducibility: seed Python's global RNG so any sampling is repeatable.
# NOTE(review): `random` is seeded but never used in this visible chunk —
# presumably kept for parity with the training script; confirm before removing.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
# --- LLM Configuration (student only for inference) ---
# Student: "openai" = OpenAI API; "vllm" = local vLLM server
USE_OPENAI_AS_STUDENT = True
# Overridable via the OPENAI_STUDENT_MODEL environment variable.
OPENAI_STUDENT_MODEL = os.environ.get("OPENAI_STUDENT_MODEL", "gpt-5")

# API keys live in a local JSON file shaped like {"openai": "<key>", ...};
# raises KeyError here if the "openai" entry is missing.
api_file = "/home/mshahidul/api_new.json"
with open(api_file, "r") as f:
    api_keys = json.load(f)
openai_api_key = api_keys["openai"]

# Student: Local vLLM (Deployment Model) — OpenAI-compatible endpoint,
# greedy decoding (temperature=0.0) for deterministic inference.
vllm_model = dspy.LM(
    model="openai/dspy",
    api_base="http://172.16.34.19:4090/v1",
    api_key="EMPTY",
    temperature=0.0,
)

# Student: OpenAI (optional)
openai_model_student = dspy.LM(
    model=OPENAI_STUDENT_MODEL,
    api_key=openai_api_key,
)

# Select the student backend and register it globally; everything below
# (including HealthLiteracyClassifier) uses this LM via dspy.configure.
student_lm = openai_model_student if USE_OPENAI_AS_STUDENT else vllm_model
dspy.configure(lm=student_lm)

student_name = f"OpenAI ({OPENAI_STUDENT_MODEL})" if USE_OPENAI_AS_STUDENT else "vLLM (local)"
print(f"Student model (inference): {student_name}")
# --- Labels, signature, and helpers (mirrors training script) ---
# Official label set, ordered from simplest to most technical language.
# _normalize_pred_to_label scans these in order, so order doubles as
# match priority for wordy model outputs.
LITERACY_LABELS = [
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
]
class HealthLiteracySignature(dspy.Signature):
    # NOTE: in dspy, this class docstring and each field's `desc` string become
    # part of the prompt sent to the LM — they are runtime behavior, not mere
    # documentation, and must not be reworded casually.
    """
    Analyze the linguistic complexity, use of medical jargon, and sentence
    structure of 'generated_text' to determine the health literacy level.
    Output exactly one of the three labels: low_health_literacy, intermediate_health_literacy, proficient_health_literacy.
    """
    # Input: the rewritten text to be graded.
    generated_text = dspy.InputField(
        desc="A version of the source text rewritten for a specific audience."
    )
    # Output: ideally one of LITERACY_LABELS verbatim; wordy answers are
    # normalized downstream by _normalize_pred_to_label.
    literacy_label = dspy.OutputField(
        desc=(
            "Exactly one of: low_health_literacy (simple words, no jargon), "
            "intermediate_health_literacy (moderate technicality), "
            "proficient_health_literacy (highly technical/original level)."
        )
    )
class HealthLiteracyClassifier(dspy.Module):
    """Chain-of-thought wrapper over HealthLiteracySignature.

    Uses whatever LM was registered globally via dspy.configure at import time.
    """

    def __init__(self):
        super().__init__()
        # Attribute name is part of the module's parameter namespace in dspy;
        # keep it stable so saved/optimized states stay compatible.
        self.classifier = dspy.ChainOfThought(HealthLiteracySignature)

    def forward(self, generated_text):
        """Classify one piece of text; returns the dspy prediction object."""
        prediction = self.classifier(generated_text=generated_text)
        return prediction
def _normalize_pred_to_label(pred_label: str) -> str:
"""Extract the first matching official label from model output (handles wordy answers)."""
pred_label = (pred_label or "").strip().lower()
for label in LITERACY_LABELS:
if label in pred_label:
return label
return pred_label
# --- Paths ---
# Data directory for the Bengali (bn) split of the classifier experiments.
BN_DIR = "/home/mshahidul/readctrl/code/text_classifier/bn"
# Input: full labeled test set.
DATA_PATH = os.path.join(BN_DIR, "testing_bn_full.json")
# Output: the ~200-example subset written by main().
OUTPUT_PATH = os.path.join(BN_DIR, "testing_bn_clean_200.json")
def _select_examples(clean_examples, difficult_examples, target_n):
    """Pick up to target_n records, preferring clean (correctly predicted) ones
    and topping up with difficult examples when there aren't enough."""
    selected = list(clean_examples[:target_n])
    if len(selected) < target_n and difficult_examples:
        selected.extend(difficult_examples[: target_n - len(selected)])
    return selected


def main(target_n=200):
    """Classify every test item with the student LM and save a subset.

    Items with unknown labels or empty text are skipped. Each kept record
    gains "predicted_label" and "prediction_correct" fields; correctly
    predicted records are preferred when filling the target_n-sized output.

    Args:
        target_n: Number of examples to save (default 200, matching the
            original hard-coded behavior).
    """
    # Initialize classifier (uses current student LM via dspy.configure above).
    classifier = HealthLiteracyClassifier()

    # Load full dataset.
    with open(DATA_PATH, "r", encoding="utf-8") as f:
        raw_data = json.load(f)
    print(f"Total input instances: {len(raw_data)}")

    clean_examples = []
    difficult_examples = []
    for item in raw_data:
        label = item.get("label")
        if label not in LITERACY_LABELS:
            # Skip unknown labels.
            continue
        # Prefer "gen_text"; fall back to "diff_label_texts"; skip empty text.
        text = item.get("gen_text") or item.get("diff_label_texts", "")
        if not text:
            continue

        pred = classifier(generated_text=text)
        gold_label = str(label).strip().lower()
        pred_raw = str(getattr(pred, "literacy_label", "") or "").strip().lower()
        pred_normalized = _normalize_pred_to_label(pred_raw)
        # Correct if the normalized label matches, or the gold label appears
        # anywhere in the raw output (covers multi-label wordy answers).
        correct = bool(gold_label == pred_normalized or gold_label in pred_raw)

        record = dict(item)
        record["predicted_label"] = pred_normalized or pred_raw or "(empty)"
        record["prediction_correct"] = correct
        if correct:
            clean_examples.append(record)
        else:
            difficult_examples.append(record)

    print(f"Correctly predicted (easy) examples: {len(clean_examples)}")
    print(f"Difficult examples (mismatch / unclear): {len(difficult_examples)}")

    # Prefer clean/easy examples; fill remaining slots with difficult ones.
    clean_200 = _select_examples(clean_examples, difficult_examples, target_n)
    print(
        f"Saving {len(clean_200)} examples to: {OUTPUT_PATH} "
        f"({sum(1 for r in clean_200 if r.get('prediction_correct'))} clean, "
        f"{sum(1 for r in clean_200 if not r.get('prediction_correct'))} difficult)"
    )
    with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
        json.dump(clean_200, f, ensure_ascii=False, indent=2)


if __name__ == "__main__":
    main()