| | import argparse |
| | import json |
| | import os |
| | from collections import Counter |
| | from typing import Dict, List, Tuple |
| |
|
| | import dspy |
| | from tqdm import tqdm |
| |
|
| |
|
# Local JSON credentials file; must contain an "openai" key (see load_openai_key).
API_FILE = "/home/mshahidul/api_new.json"
# Default locations of the compiled DSPy model and the evaluation inputs/outputs.
DEFAULT_MODEL_PATH = "/home/mshahidul/readctrl/code/text_classifier/dspy_model/student-gpt5-mini_teacher-gpt5_v1/model.json"
DEFAULT_DATASET_PATH = "/home/mshahidul/readctrl/code/text_classifier/verified_combined_0-80.json"
DEFAULT_OUTPUT_PATH = "/home/mshahidul/readctrl/code/text_classifier/dspy_model/student-gpt5-mini_teacher-gpt5_v1/full_dataset_accuracy.json"
DEFAULT_PREDICTIONS_PATH = "/home/mshahidul/readctrl/code/text_classifier/dspy_model/student-gpt5-mini_teacher-gpt5_v1/full_dataset_predictions.json"
# Outputs of the clean/removed split produced after evaluation.
DEFAULT_CLEAN_DATASET_PATH = "/home/mshahidul/readctrl/code/text_classifier/verified_combined_0-80_clean200.json"
DEFAULT_REMOVED_PATH = "/home/mshahidul/readctrl/code/text_classifier/verified_combined_0-80_removed21.json"
# Closed set of labels the classifier recognizes; anything else is discarded.
VALID_LABELS = {
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
}
# Ordinal rank per label, used to measure how far off a misclassification is.
LABEL_ORDER = {
    "low_health_literacy": 0,
    "intermediate_health_literacy": 1,
    "proficient_health_literacy": 2,
}
| |
|
| |
|
class HealthLiteracySignature(dspy.Signature):
    # NOTE(review): dspy presumably uses this docstring as the task instruction
    # in the prompt, so its wording affects model behavior — left unchanged.
    """
    Analyze the linguistic complexity, use of medical jargon, and sentence
    structure of 'generated_text' to determine the health literacy level.
    """

    # Input: the rewritten text whose literacy level we want to classify.
    generated_text = dspy.InputField(
        desc="A version of the source text rewritten for a specific audience."
    )
    # Output: expected to be one of the three labels in VALID_LABELS; the raw
    # text is later normalized by extract_predicted_label.
    literacy_label = dspy.OutputField(
        desc=(
            "Classification: low_health_literacy (simple words, no jargon), "
            "intermediate_health_literacy (moderate technicality), or "
            "proficient_health_literacy (highly technical/original level)."
        )
    )
| |
|
| |
|
class HealthLiteracyClassifier(dspy.Module):
    """DSPy module that classifies health-literacy level via chain-of-thought."""

    def __init__(self):
        super().__init__()
        # Chain-of-thought prompting over HealthLiteracySignature.
        self.classifier = dspy.ChainOfThought(HealthLiteracySignature)

    def forward(self, generated_text):
        """Run the classifier on one text and return the dspy prediction."""
        prediction = self.classifier(generated_text=generated_text)
        return prediction
| |
|
| |
|
def load_openai_key(api_file: str) -> str:
    """Read the OpenAI API key from a JSON credentials file.

    Args:
        api_file: Path to a JSON file expected to contain an "openai" entry.

    Returns:
        The value stored under the "openai" key.

    Raises:
        KeyError: if the file contains no "openai" entry.
    """
    with open(api_file, "r") as handle:
        credentials = json.load(handle)
    try:
        return credentials["openai"]
    except KeyError:
        # Same exception type and message as a manual membership check.
        raise KeyError(f"'openai' key is missing in {api_file}") from None
| |
|
| |
|
def normalize_label(text: str) -> str:
    """Canonicalize a label-like value: falsy/None -> "", else trimmed lowercase."""
    if not text:
        return ""
    return str(text).strip().lower()
| |
|
| |
|
def is_correct(gold_label: str, predicted_label: str) -> bool:
    """Lenient correctness: the normalized gold label must occur as a
    substring of the normalized prediction text."""
    gold = str(gold_label or "").strip().lower()
    prediction = str(predicted_label or "").strip().lower()
    return gold in prediction
| |
|
| |
|
def extract_predicted_label(predicted_text: str) -> str:
    """Map raw model output to exactly one canonical label, or "".

    Returns "" when zero valid labels occur in the text, and also when more
    than one occurs (ambiguous output counts as no prediction).
    """
    normalized = normalize_label(predicted_text)
    hits = [label for label in VALID_LABELS if label in normalized]
    return hits[0] if len(hits) == 1 else ""
| |
|
| |
|
def misclassification_severity(gold_label: str, predicted_label: str) -> int:
    """Ordinal distance (0-2) between gold and predicted labels.

    Unknown or empty labels (anything outside LABEL_ORDER) are treated as
    worst-case errors and receive the maximum penalty of 3.
    """
    gold_rank = LABEL_ORDER.get(gold_label)
    pred_rank = LABEL_ORDER.get(predicted_label)
    if gold_rank is None or pred_rank is None:
        return 3
    return abs(gold_rank - pred_rank)
| |
|
| |
|
def load_full_examples(dataset_path: str):
    """Load labeled examples from the dataset JSON file.

    Keeps only records with a label in VALID_LABELS and a non-empty
    "diff_label_texts" field. Each kept example carries the original record
    under "raw_item" so the clean/removed splits can be written back verbatim.

    Raises:
        ValueError: when no valid labeled examples are found.
    """
    with open(dataset_path, "r") as handle:
        records = json.load(handle)

    kept = []
    for position, record in enumerate(records):
        gold = record.get("label")
        body = record.get("diff_label_texts")
        if gold not in VALID_LABELS or not body:
            continue
        kept.append(
            {
                "index": position,
                "generated_text": body,
                "gold_label": gold,
                "doc_id": record.get("doc_id"),
                "raw_item": record,
            }
        )
    if not kept:
        raise ValueError("No valid labeled examples found in dataset.")
    return kept
| |
|
| |
|
def choose_indices_to_remove(
    predictions: List[Dict], remove_count: int
) -> Tuple[List[Dict], List[int]]:
    """Pick `remove_count` predictions to drop from the dataset.

    Preference order: misclassified examples first, balanced across gold
    labels via per-label quotas, with deterministic tie-breaking. Returns the
    removed prediction dicts and their sorted dataset indices.
    """
    def _rank_key(p: Dict):
        # Lower tuple sorts first, i.e. "remove earlier":
        #   1) misclassified before correct,
        #   2) higher severity first,
        #   3) unparseable (empty) predicted label first,
        #   4) longer raw prediction text first,
        #   5) dataset index as the final deterministic tie-break.
        return (
            0 if not p["exact_correct"] else 1,
            -p["severity"],
            0 if not p["predicted_label"] else 1,
            -len(normalize_label(p["raw_prediction_text"])),
            p["index"],
        )

    # Bucket predictions by gold label: all examples, and misclassified only.
    label_sequence = sorted(VALID_LABELS, key=lambda x: LABEL_ORDER[x])
    per_label_all = {label: [] for label in label_sequence}
    per_label_mis = {label: [] for label in label_sequence}
    for p in predictions:
        label = p["gold_label"]
        if label in per_label_all:
            per_label_all[label].append(p)
            if not p["exact_correct"]:
                per_label_mis[label].append(p)

    for label in label_sequence:
        per_label_all[label].sort(key=_rank_key)
        per_label_mis[label].sort(key=_rank_key)

    # Split remove_count evenly into per-label quotas.
    num_labels = len(label_sequence)
    base_quota = remove_count // num_labels
    remainder = remove_count % num_labels
    quotas = {label: base_quota for label in label_sequence}

    # Leftover slots go to the labels with the most misclassifications
    # (ties broken by label ordinal).
    remainder_order = sorted(
        label_sequence,
        key=lambda label: (-len(per_label_mis[label]), LABEL_ORDER[label]),
    )
    for label in remainder_order[:remainder]:
        quotas[label] += 1

    removed = []
    removed_indices_set = set()

    # Pass 1: fill each label's quota from its misclassified examples.
    for label in label_sequence:
        take = min(quotas[label], len(per_label_mis[label]))
        for item in per_label_mis[label][:take]:
            removed.append(item)
            removed_indices_set.add(item["index"])

    # Pass 2: top up any under-filled quota from that label's remaining
    # examples (correct ones included), still in rank order.
    for label in label_sequence:
        needed = quotas[label] - sum(1 for x in removed if x["gold_label"] == label)
        if needed <= 0:
            continue
        candidates = [
            x for x in per_label_all[label] if x["index"] not in removed_indices_set
        ]
        for item in candidates[:needed]:
            removed.append(item)
            removed_indices_set.add(item["index"])

    # Pass 3: if still short (small labels exhausted), take globally best-
    # ranked remaining predictions regardless of label.
    if len(removed) < remove_count:
        remaining_global = sorted(
            (p for p in predictions if p["index"] not in removed_indices_set),
            key=_rank_key,
        )
        need = remove_count - len(removed)
        for item in remaining_global[:need]:
            removed.append(item)
            removed_indices_set.add(item["index"])

    # Final trim guards against any overshoot and fixes the output order.
    removed = sorted(removed, key=_rank_key)[:remove_count]
    removed_indices = sorted(p["index"] for p in removed)
    return removed, removed_indices
| |
|
| |
|
def run_inference(
    model_path: str,
    dataset_path: str,
    output_path: str,
    predictions_path: str,
    clean_dataset_path: str,
    removed_path: str,
    target_clean_size: int,
):
    """Evaluate the compiled classifier on the full dataset and write splits.

    Runs the model over every valid example, computes lenient/exact accuracy,
    then removes (total - target_clean_size) examples — preferring
    misclassified ones, balanced per label — and writes four files: the
    accuracy report, per-example predictions, the clean dataset, and the
    removed examples.

    Raises:
        ValueError: if target_clean_size is not a strict, non-empty subset
            size of the dataset.
    """
    # Configure dspy with the student LM before loading the compiled module.
    openai_api_key = load_openai_key(API_FILE)
    student_lm = dspy.LM(model="gpt-5-mini", api_key=openai_api_key)
    dspy.configure(lm=student_lm)

    classifier = HealthLiteracyClassifier()
    # Load the compiled/optimized program state from disk.
    classifier.load(model_path)

    examples = load_full_examples(dataset_path)
    total = len(examples)
    if target_clean_size <= 0 or target_clean_size >= total:
        raise ValueError(
            f"target_clean_size must be between 1 and {total - 1}, got {target_clean_size}"
        )

    remove_count = total - target_clean_size
    correct = 0
    label_totals = Counter()
    label_correct = Counter()
    predictions = []

    for idx, ex in enumerate(
        tqdm(examples, desc="Classifying full dataset", unit="sample"), start=1
    ):
        pred = classifier(generated_text=ex["generated_text"])
        raw_pred_label = getattr(pred, "literacy_label", "")
        # Exact: raw output maps to exactly one canonical label that matches.
        pred_label = extract_predicted_label(raw_pred_label)
        gold_label = ex["gold_label"]
        exact_correct = pred_label == gold_label
        # Lenient: the gold label merely appears somewhere in the raw output.
        lenient_correct = is_correct(gold_label, raw_pred_label)
        severity = (
            misclassification_severity(gold_label, pred_label) if not exact_correct else 0
        )

        # Per-label stats are based on the lenient criterion.
        label_totals[gold_label] += 1
        if lenient_correct:
            correct += 1
            label_correct[gold_label] += 1

        predictions.append(
            {
                "index": ex["index"],
                "doc_id": ex["doc_id"],
                "gold_label": gold_label,
                "predicted_label": pred_label,
                "raw_prediction_text": raw_pred_label,
                "lenient_correct": lenient_correct,
                "exact_correct": exact_correct,
                "severity": severity,
                "generated_text": ex["generated_text"],
            }
        )

        # Periodic progress line in addition to the tqdm bar.
        if idx % 10 == 0 or idx == total:
            tqdm.write(f"Processed {idx}/{total}")

    accuracy = correct / total if total else 0.0
    exact_accuracy = (
        sum(1 for p in predictions if p["exact_correct"]) / total if total else 0.0
    )
    per_label_accuracy = {
        label: (
            (label_correct[label] / label_totals[label]) if label_totals[label] else 0.0
        )
        for label in sorted(VALID_LABELS)
    }
    # Decide which examples to drop, then split the ORIGINAL records verbatim.
    removed_examples, removed_indices = choose_indices_to_remove(predictions, remove_count)
    removed_index_set = set(removed_indices)
    clean_dataset = [
        p["raw_item"]
        for p in examples
        if p["index"] not in removed_index_set
    ]
    removed_dataset = [
        p["raw_item"]
        for p in examples
        if p["index"] in removed_index_set
    ]

    report = {
        "model_path": model_path,
        "dataset_path": dataset_path,
        "num_examples": total,
        "num_correct": correct,
        "lenient_accuracy": accuracy,
        "exact_accuracy": exact_accuracy,
        "per_label_accuracy": per_label_accuracy,
        "target_clean_size": target_clean_size,
        "removed_count": remove_count,
        "clean_dataset_size": len(clean_dataset),
        "removed_dataset_size": len(removed_dataset),
        "removed_misclassified_count": sum(
            1 for p in removed_examples if not p["exact_correct"]
        ),
        "removed_per_label": dict(
            Counter(p["gold_label"] for p in removed_examples)
        ),
    }

    # Ensure every output directory exists before writing.
    for path in [
        output_path,
        predictions_path,
        clean_dataset_path,
        removed_path,
    ]:
        output_dir = os.path.dirname(path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    with open(output_path, "w") as f:
        json.dump(report, f, indent=2)
    with open(predictions_path, "w") as f:
        json.dump(predictions, f, indent=2)
    with open(clean_dataset_path, "w") as f:
        json.dump(clean_dataset, f, indent=2, ensure_ascii=False)
    with open(removed_path, "w") as f:
        json.dump(removed_dataset, f, indent=2, ensure_ascii=False)

    print(json.dumps(report, indent=2))
    print(f"Saved predictions to: {predictions_path}")
    print(f"Saved clean dataset to: {clean_dataset_path}")
    print(f"Saved removed examples to: {removed_path}")
    print(f"Saved report to: {output_path}")
| |
|
| |
|
def main():
    """CLI entry point: parse arguments and run the full-dataset evaluation."""
    parser = argparse.ArgumentParser(
        description="Load a compiled DSPy classifier and evaluate on full dataset."
    )
    # Path-valued flags share the same registration shape; declare them once.
    path_flags = [
        ("--model-path", DEFAULT_MODEL_PATH),
        ("--dataset-path", DEFAULT_DATASET_PATH),
        ("--output-path", DEFAULT_OUTPUT_PATH),
        ("--predictions-path", DEFAULT_PREDICTIONS_PATH),
        ("--clean-dataset-path", DEFAULT_CLEAN_DATASET_PATH),
        ("--removed-path", DEFAULT_REMOVED_PATH),
    ]
    for flag, default in path_flags:
        parser.add_argument(flag, default=default)
    parser.add_argument("--target-clean-size", type=int, default=200)
    args = parser.parse_args()

    run_inference(
        model_path=args.model_path,
        dataset_path=args.dataset_path,
        output_path=args.output_path,
        predictions_path=args.predictions_path,
        clean_dataset_path=args.clean_dataset_path,
        removed_path=args.removed_path,
        target_clean_size=args.target_clean_size,
    )
| |
|
| |
|
# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
| |
|