File size: 4,953 Bytes
030876e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import dspy
import json
import os
import random


# Reproducibility — seed the stdlib RNG so any sampling downstream is stable.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)


# --- LLM Configuration (student only for inference) ---
# Student: "openai" = OpenAI API; "vllm" = local vLLM server
USE_OPENAI_AS_STUDENT = True
OPENAI_STUDENT_MODEL = os.environ.get("OPENAI_STUDENT_MODEL", "gpt-5")

# API keys are kept outside the repo in a JSON file mapping provider -> key.
api_file = "/home/mshahidul/api_new.json"
with open(api_file, "r") as f:
    api_keys = json.load(f)
openai_api_key = api_keys["openai"]

# Student: Local vLLM (Deployment Model)
# vLLM serves an OpenAI-compatible endpoint, hence the "openai/" model prefix
# and the dummy api_key.
vllm_model = dspy.LM(
    model="openai/dspy",
    api_base="http://172.16.34.19:4090/v1",
    api_key="EMPTY",
    temperature=0.0,  # deterministic decoding for reproducible labels
)

# Student: OpenAI (optional)
openai_model_student = dspy.LM(
    model=OPENAI_STUDENT_MODEL,
    api_key=openai_api_key,
)

# Select the student LM and make it the process-wide default for dspy modules.
student_lm = openai_model_student if USE_OPENAI_AS_STUDENT else vllm_model
dspy.configure(lm=student_lm)

student_name = f"OpenAI ({OPENAI_STUDENT_MODEL})" if USE_OPENAI_AS_STUDENT else "vLLM (local)"
print(f"Student model (inference): {student_name}")


# --- Labels, signature, and helpers (mirrors training script) ---
# Canonical label set used both as gold labels in the dataset and as the
# allowed model outputs; list order defines matching priority when a wordy
# answer contains more than one label substring.
LITERACY_LABELS = [
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
]


class HealthLiteracySignature(dspy.Signature):
    # NOTE(review): in dspy, this class docstring is sent to the LM as the
    # task instruction — its exact wording is runtime behavior, so it is
    # intentionally left unchanged here.
    """
    Analyze the linguistic complexity, use of medical jargon, and sentence
    structure of 'generated_text' to determine the health literacy level.
    Output exactly one of the three labels: low_health_literacy, intermediate_health_literacy, proficient_health_literacy.
    """

    # Input: the rewritten passage to be graded for health-literacy level.
    generated_text = dspy.InputField(
        desc="A version of the source text rewritten for a specific audience."
    )

    # Output: one label from LITERACY_LABELS; the desc doubles as in-prompt
    # guidance for the model, so it is also behavior and left unchanged.
    literacy_label = dspy.OutputField(
        desc=(
            "Exactly one of: low_health_literacy (simple words, no jargon), "
            "intermediate_health_literacy (moderate technicality), "
            "proficient_health_literacy (highly technical/original level)."
        )
    )


class HealthLiteracyClassifier(dspy.Module):
    """Chain-of-thought wrapper around HealthLiteracySignature.

    Uses whatever LM was installed via dspy.configure(...).
    """

    def __init__(self):
        super().__init__()
        # Kept as `self.classifier` — dspy keys sub-modules by attribute name.
        self.classifier = dspy.ChainOfThought(HealthLiteracySignature)

    def forward(self, generated_text):
        """Return the dspy Prediction (with `literacy_label`) for one text."""
        prediction = self.classifier(generated_text=generated_text)
        return prediction


def _normalize_pred_to_label(pred_label: str) -> str:
    """Extract the first matching official label from model output (handles wordy answers)."""
    pred_label = (pred_label or "").strip().lower()
    for label in LITERACY_LABELS:
        if label in pred_label:
            return label
    return pred_label


# --- Paths ---
# Input: full labelled test set; Output: the filtered 200-example subset.
BN_DIR = "/home/mshahidul/readctrl/code/text_classifier/bn"
DATA_PATH = os.path.join(BN_DIR, "testing_bn_full.json")
OUTPUT_PATH = os.path.join(BN_DIR, "testing_bn_clean_200.json")


def _evaluate_item(classifier, item):
    """Classify one dataset item with the student LM.

    Returns a copy of *item* augmented with `predicted_label` and
    `prediction_correct`, or None when the item should be skipped
    (unknown gold label, or no usable text).
    """
    label = item.get("label")
    if label not in LITERACY_LABELS:
        return None  # skip unknown labels

    # Prefer the generated text; fall back to the alternate field.
    text = item.get("gen_text") or item.get("diff_label_texts", "")
    if not text:
        return None

    pred = classifier(generated_text=text)
    gold_label = str(label).strip().lower()
    pred_raw = str(getattr(pred, "literacy_label", "") or "").strip().lower()
    pred_normalized = _normalize_pred_to_label(pred_raw)

    # Correct when the normalized prediction matches exactly, or the gold
    # label appears anywhere in the raw answer (tolerates wordy outputs).
    correct = bool(gold_label == pred_normalized or gold_label in pred_raw)

    record = dict(item)
    record["predicted_label"] = pred_normalized or pred_raw or "(empty)"
    record["prediction_correct"] = correct
    return record


def _select_examples(clean_examples, difficult_examples, target_n=200):
    """Pick up to *target_n* records, preferring clean (correct) examples.

    Any shortfall is filled, in order, from the difficult examples.
    """
    selected = list(clean_examples[:target_n])
    if len(selected) < target_n:
        selected.extend(difficult_examples[: target_n - len(selected)])
    return selected


def main():
    """Classify every item in DATA_PATH and save 200 examples to OUTPUT_PATH."""
    # Uses the student LM installed by dspy.configure at import time.
    classifier = HealthLiteracyClassifier()

    # Load full dataset (list of dicts with "label" and text fields).
    with open(DATA_PATH, "r", encoding="utf-8") as f:
        raw_data = json.load(f)
    print(f"Total input instances: {len(raw_data)}")

    clean_examples = []
    difficult_examples = []
    for item in raw_data:
        record = _evaluate_item(classifier, item)
        if record is None:
            continue
        if record["prediction_correct"]:
            clean_examples.append(record)
        else:
            difficult_examples.append(record)

    print(f"Correctly predicted (easy) examples: {len(clean_examples)}")
    print(f"Difficult examples (mismatch / unclear): {len(difficult_examples)}")

    # Target: 200 examples total, clean-first, topped up with difficult ones.
    selected = _select_examples(clean_examples, difficult_examples, target_n=200)
    n_clean = sum(1 for r in selected if r.get("prediction_correct"))
    print(
        f"Saving {len(selected)} examples to: {OUTPUT_PATH} "
        f"({n_clean} clean, "
        f"{len(selected) - n_clean} difficult)"
    )

    with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
        json.dump(selected, f, ensure_ascii=False, indent=2)


if __name__ == "__main__":
    main()