# readctrl/code/text_classifier/text_classifier_dspy_load_and_infer_full.py
# Uploaded by shahidul034 ("Add files using upload-large-folder tool",
# commit 1db7196, verified).
import argparse
import json
import os
from collections import Counter
from typing import Dict, List, Tuple
import dspy
from tqdm import tqdm
API_FILE = "/home/mshahidul/api_new.json"
DEFAULT_MODEL_PATH = "/home/mshahidul/readctrl/code/text_classifier/dspy_model/student-gpt5-mini_teacher-gpt5_v1/model.json"
DEFAULT_DATASET_PATH = "/home/mshahidul/readctrl/code/text_classifier/verified_combined_0-80.json"
DEFAULT_OUTPUT_PATH = "/home/mshahidul/readctrl/code/text_classifier/dspy_model/student-gpt5-mini_teacher-gpt5_v1/full_dataset_accuracy.json"
DEFAULT_PREDICTIONS_PATH = "/home/mshahidul/readctrl/code/text_classifier/dspy_model/student-gpt5-mini_teacher-gpt5_v1/full_dataset_predictions.json"
DEFAULT_CLEAN_DATASET_PATH = "/home/mshahidul/readctrl/code/text_classifier/verified_combined_0-80_clean200.json"
DEFAULT_REMOVED_PATH = "/home/mshahidul/readctrl/code/text_classifier/verified_combined_0-80_removed21.json"
VALID_LABELS = {
"low_health_literacy",
"intermediate_health_literacy",
"proficient_health_literacy",
}
LABEL_ORDER = {
"low_health_literacy": 0,
"intermediate_health_literacy": 1,
"proficient_health_literacy": 2,
}
class HealthLiteracySignature(dspy.Signature):
"""
Analyze the linguistic complexity, use of medical jargon, and sentence
structure of 'generated_text' to determine the health literacy level.
"""
generated_text = dspy.InputField(
desc="A version of the source text rewritten for a specific audience."
)
literacy_label = dspy.OutputField(
desc=(
"Classification: low_health_literacy (simple words, no jargon), "
"intermediate_health_literacy (moderate technicality), or "
"proficient_health_literacy (highly technical/original level)."
)
)
class HealthLiteracyClassifier(dspy.Module):
def __init__(self):
super().__init__()
self.classifier = dspy.ChainOfThought(HealthLiteracySignature)
def forward(self, generated_text):
return self.classifier(generated_text=generated_text)
def load_openai_key(api_file: str) -> str:
with open(api_file, "r") as f:
api_keys = json.load(f)
if "openai" not in api_keys:
raise KeyError(f"'openai' key is missing in {api_file}")
return api_keys["openai"]
def normalize_label(text: str) -> str:
return str(text or "").strip().lower()
def is_correct(gold_label: str, predicted_label: str) -> bool:
gold = normalize_label(gold_label)
pred = normalize_label(predicted_label)
return gold in pred
def extract_predicted_label(predicted_text: str) -> str:
pred = normalize_label(predicted_text)
matched = [label for label in VALID_LABELS if label in pred]
if len(matched) == 1:
return matched[0]
return ""
def misclassification_severity(gold_label: str, predicted_label: str) -> int:
gold = LABEL_ORDER.get(gold_label)
pred = LABEL_ORDER.get(predicted_label)
if gold is None or pred is None:
# Unknown/unparseable predictions are treated as worst.
return 3
return abs(gold - pred)
def load_full_examples(dataset_path: str):
with open(dataset_path, "r") as f:
raw_data = json.load(f)
examples = []
for idx, item in enumerate(raw_data):
label = item.get("label")
text = item.get("diff_label_texts")
if label in VALID_LABELS and text:
examples.append(
{
"index": idx,
"generated_text": text,
"gold_label": label,
"doc_id": item.get("doc_id"),
"raw_item": item,
}
)
if not examples:
raise ValueError("No valid labeled examples found in dataset.")
return examples
def choose_indices_to_remove(
predictions: List[Dict], remove_count: int
) -> Tuple[List[Dict], List[int]]:
def _rank_key(p: Dict):
return (
0 if not p["exact_correct"] else 1,
-p["severity"],
0 if not p["predicted_label"] else 1,
-len(normalize_label(p["raw_prediction_text"])),
p["index"],
)
label_sequence = sorted(VALID_LABELS, key=lambda x: LABEL_ORDER[x])
per_label_all = {label: [] for label in label_sequence}
per_label_mis = {label: [] for label in label_sequence}
for p in predictions:
label = p["gold_label"]
if label in per_label_all:
per_label_all[label].append(p)
if not p["exact_correct"]:
per_label_mis[label].append(p)
for label in label_sequence:
per_label_all[label].sort(key=_rank_key)
per_label_mis[label].sort(key=_rank_key)
# Balanced quota (approximately equal removals per label).
num_labels = len(label_sequence)
base_quota = remove_count // num_labels
remainder = remove_count % num_labels
quotas = {label: base_quota for label in label_sequence}
# Assign remainder to labels with more misclassified candidates first.
remainder_order = sorted(
label_sequence,
key=lambda label: (-len(per_label_mis[label]), LABEL_ORDER[label]),
)
for label in remainder_order[:remainder]:
quotas[label] += 1
removed = []
removed_indices_set = set()
# First pass: satisfy each label quota with misclassified items.
for label in label_sequence:
take = min(quotas[label], len(per_label_mis[label]))
for item in per_label_mis[label][:take]:
removed.append(item)
removed_indices_set.add(item["index"])
# Second pass: if some quotas could not be met, fill within those labels
# using next-worst remaining items (can include correctly classified).
for label in label_sequence:
needed = quotas[label] - sum(1 for x in removed if x["gold_label"] == label)
if needed <= 0:
continue
candidates = [
x for x in per_label_all[label] if x["index"] not in removed_indices_set
]
for item in candidates[:needed]:
removed.append(item)
removed_indices_set.add(item["index"])
# Final pass: if still short (edge cases), fill globally by worst rank.
if len(removed) < remove_count:
remaining_global = sorted(
(p for p in predictions if p["index"] not in removed_indices_set),
key=_rank_key,
)
need = remove_count - len(removed)
for item in remaining_global[:need]:
removed.append(item)
removed_indices_set.add(item["index"])
# Keep deterministic order in output by rank.
removed = sorted(removed, key=_rank_key)[:remove_count]
removed_indices = sorted(p["index"] for p in removed)
return removed, removed_indices
def run_inference(
model_path: str,
dataset_path: str,
output_path: str,
predictions_path: str,
clean_dataset_path: str,
removed_path: str,
target_clean_size: int,
):
openai_api_key = load_openai_key(API_FILE)
student_lm = dspy.LM(model="gpt-5-mini", api_key=openai_api_key)
dspy.configure(lm=student_lm)
classifier = HealthLiteracyClassifier()
classifier.load(model_path)
examples = load_full_examples(dataset_path)
total = len(examples)
if target_clean_size <= 0 or target_clean_size >= total:
raise ValueError(
f"target_clean_size must be between 1 and {total - 1}, got {target_clean_size}"
)
remove_count = total - target_clean_size
correct = 0
label_totals = Counter()
label_correct = Counter()
predictions = []
for idx, ex in enumerate(
tqdm(examples, desc="Classifying full dataset", unit="sample"), start=1
):
pred = classifier(generated_text=ex["generated_text"])
raw_pred_label = getattr(pred, "literacy_label", "")
pred_label = extract_predicted_label(raw_pred_label)
gold_label = ex["gold_label"]
exact_correct = pred_label == gold_label
lenient_correct = is_correct(gold_label, raw_pred_label)
severity = (
misclassification_severity(gold_label, pred_label) if not exact_correct else 0
)
label_totals[gold_label] += 1
if lenient_correct:
correct += 1
label_correct[gold_label] += 1
predictions.append(
{
"index": ex["index"],
"doc_id": ex["doc_id"],
"gold_label": gold_label,
"predicted_label": pred_label,
"raw_prediction_text": raw_pred_label,
"lenient_correct": lenient_correct,
"exact_correct": exact_correct,
"severity": severity,
"generated_text": ex["generated_text"],
}
)
if idx % 10 == 0 or idx == total:
tqdm.write(f"Processed {idx}/{total}")
accuracy = correct / total if total else 0.0
exact_accuracy = (
sum(1 for p in predictions if p["exact_correct"]) / total if total else 0.0
)
per_label_accuracy = {
label: (
(label_correct[label] / label_totals[label]) if label_totals[label] else 0.0
)
for label in sorted(VALID_LABELS)
}
removed_examples, removed_indices = choose_indices_to_remove(predictions, remove_count)
removed_index_set = set(removed_indices)
clean_dataset = [
p["raw_item"]
for p in examples
if p["index"] not in removed_index_set
]
removed_dataset = [
p["raw_item"]
for p in examples
if p["index"] in removed_index_set
]
report = {
"model_path": model_path,
"dataset_path": dataset_path,
"num_examples": total,
"num_correct": correct,
"lenient_accuracy": accuracy,
"exact_accuracy": exact_accuracy,
"per_label_accuracy": per_label_accuracy,
"target_clean_size": target_clean_size,
"removed_count": remove_count,
"clean_dataset_size": len(clean_dataset),
"removed_dataset_size": len(removed_dataset),
"removed_misclassified_count": sum(
1 for p in removed_examples if not p["exact_correct"]
),
"removed_per_label": dict(
Counter(p["gold_label"] for p in removed_examples)
),
}
for path in [
output_path,
predictions_path,
clean_dataset_path,
removed_path,
]:
output_dir = os.path.dirname(path)
if output_dir:
os.makedirs(output_dir, exist_ok=True)
with open(output_path, "w") as f:
json.dump(report, f, indent=2)
with open(predictions_path, "w") as f:
json.dump(predictions, f, indent=2)
with open(clean_dataset_path, "w") as f:
json.dump(clean_dataset, f, indent=2, ensure_ascii=False)
with open(removed_path, "w") as f:
json.dump(removed_dataset, f, indent=2, ensure_ascii=False)
print(json.dumps(report, indent=2))
print(f"Saved predictions to: {predictions_path}")
print(f"Saved clean dataset to: {clean_dataset_path}")
print(f"Saved removed examples to: {removed_path}")
print(f"Saved report to: {output_path}")
def main():
parser = argparse.ArgumentParser(
description="Load a compiled DSPy classifier and evaluate on full dataset."
)
parser.add_argument("--model-path", default=DEFAULT_MODEL_PATH)
parser.add_argument("--dataset-path", default=DEFAULT_DATASET_PATH)
parser.add_argument("--output-path", default=DEFAULT_OUTPUT_PATH)
parser.add_argument("--predictions-path", default=DEFAULT_PREDICTIONS_PATH)
parser.add_argument("--clean-dataset-path", default=DEFAULT_CLEAN_DATASET_PATH)
parser.add_argument("--removed-path", default=DEFAULT_REMOVED_PATH)
parser.add_argument("--target-clean-size", type=int, default=200)
args = parser.parse_args()
run_inference(
model_path=args.model_path,
dataset_path=args.dataset_path,
output_path=args.output_path,
predictions_path=args.predictions_path,
clean_dataset_path=args.clean_dataset_path,
removed_path=args.removed_path,
target_clean_size=args.target_clean_size,
)
if __name__ == "__main__":
main()