"""Evaluate a compiled DSPy health-literacy classifier on a full dataset.

Loads a saved student model, classifies every labeled example, reports
lenient/exact/per-label accuracy, and carves the dataset into a "clean"
subset of a target size plus a "removed" subset chosen (balanced per label)
from the worst-classified examples.
"""

import argparse
import json
import os
from collections import Counter
from typing import Dict, List, Tuple

import dspy
from tqdm import tqdm

API_FILE = "/home/mshahidul/api_new.json"
DEFAULT_MODEL_PATH = "/home/mshahidul/readctrl/code/text_classifier/dspy_model/student-gpt5-mini_teacher-gpt5_v1/model.json"
DEFAULT_DATASET_PATH = "/home/mshahidul/readctrl/code/text_classifier/verified_combined_0-80.json"
DEFAULT_OUTPUT_PATH = "/home/mshahidul/readctrl/code/text_classifier/dspy_model/student-gpt5-mini_teacher-gpt5_v1/full_dataset_accuracy.json"
DEFAULT_PREDICTIONS_PATH = "/home/mshahidul/readctrl/code/text_classifier/dspy_model/student-gpt5-mini_teacher-gpt5_v1/full_dataset_predictions.json"
DEFAULT_CLEAN_DATASET_PATH = "/home/mshahidul/readctrl/code/text_classifier/verified_combined_0-80_clean200.json"
DEFAULT_REMOVED_PATH = "/home/mshahidul/readctrl/code/text_classifier/verified_combined_0-80_removed21.json"

VALID_LABELS = {
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
}

# Ordinal positions of the labels; used to score how "far off" a
# misclassification is (distance between gold and predicted ordinals).
LABEL_ORDER = {
    "low_health_literacy": 0,
    "intermediate_health_literacy": 1,
    "proficient_health_literacy": 2,
}


class HealthLiteracySignature(dspy.Signature):
    """
    Analyze the linguistic complexity, use of medical jargon, and sentence
    structure of 'generated_text' to determine the health literacy level.
    """

    generated_text = dspy.InputField(
        desc="A version of the source text rewritten for a specific audience."
    )
    literacy_label = dspy.OutputField(
        desc=(
            "Classification: low_health_literacy (simple words, no jargon), "
            "intermediate_health_literacy (moderate technicality), or "
            "proficient_health_literacy (highly technical/original level)."
        )
    )


class HealthLiteracyClassifier(dspy.Module):
    """Chain-of-thought DSPy module wrapping HealthLiteracySignature."""

    def __init__(self):
        super().__init__()
        self.classifier = dspy.ChainOfThought(HealthLiteracySignature)

    def forward(self, generated_text):
        """Classify one rewritten text; returns the DSPy prediction object."""
        return self.classifier(generated_text=generated_text)


def load_openai_key(api_file: str) -> str:
    """Read the OpenAI API key from a JSON file with an 'openai' entry.

    Raises:
        KeyError: if the file has no 'openai' key.
    """
    with open(api_file, "r", encoding="utf-8") as f:
        api_keys = json.load(f)
    if "openai" not in api_keys:
        raise KeyError(f"'openai' key is missing in {api_file}")
    return api_keys["openai"]


def normalize_label(text: str) -> str:
    """Lowercase and strip a label string; None becomes ''."""
    return str(text or "").strip().lower()


def is_correct(gold_label: str, predicted_label: str) -> bool:
    """Lenient match: the gold label appears anywhere in the prediction text."""
    gold = normalize_label(gold_label)
    pred = normalize_label(predicted_label)
    return gold in pred


def extract_predicted_label(predicted_text: str) -> str:
    """Return the single valid label mentioned in the prediction, else ''.

    If zero or more than one valid label appears in the text, the
    prediction is considered unparseable and '' is returned.
    """
    pred = normalize_label(predicted_text)
    matched = [label for label in VALID_LABELS if label in pred]
    if len(matched) == 1:
        return matched[0]
    return ""


def misclassification_severity(gold_label: str, predicted_label: str) -> int:
    """Ordinal distance between gold and predicted labels (0, 1, or 2).

    Unknown/unparseable predictions get 3 (worse than any real distance).
    """
    gold = LABEL_ORDER.get(gold_label)
    pred = LABEL_ORDER.get(predicted_label)
    if gold is None or pred is None:
        # Unknown/unparseable predictions are treated as worst.
        return 3
    return abs(gold - pred)


def load_full_examples(dataset_path: str) -> List[Dict]:
    """Load all validly-labeled examples from the dataset JSON file.

    Each returned dict carries the original item under 'raw_item' so the
    dataset can later be re-serialized without loss.

    Raises:
        ValueError: if no item has a valid label and non-empty text.
    """
    with open(dataset_path, "r", encoding="utf-8") as f:
        raw_data = json.load(f)

    examples = []
    for idx, item in enumerate(raw_data):
        label = item.get("label")
        text = item.get("diff_label_texts")
        if label in VALID_LABELS and text:
            examples.append(
                {
                    "index": idx,
                    "generated_text": text,
                    "gold_label": label,
                    "doc_id": item.get("doc_id"),
                    "raw_item": item,
                }
            )
    if not examples:
        raise ValueError("No valid labeled examples found in dataset.")
    return examples


def choose_indices_to_remove(
    predictions: List[Dict], remove_count: int
) -> Tuple[List[Dict], List[int]]:
    """Pick `remove_count` examples to drop, balanced across gold labels.

    Preference order within each label: misclassified first, then higher
    severity, then unparseable predictions, then longer raw prediction
    text, then lower dataset index (deterministic tiebreak).

    Returns:
        (removed prediction records sorted by rank, sorted removed indices)
    """

    def _rank_key(p: Dict):
        # Smaller tuples sort first = "worse" examples removed first.
        return (
            0 if not p["exact_correct"] else 1,
            -p["severity"],
            0 if not p["predicted_label"] else 1,
            -len(normalize_label(p["raw_prediction_text"])),
            p["index"],
        )

    label_sequence = sorted(VALID_LABELS, key=lambda x: LABEL_ORDER[x])
    per_label_all = {label: [] for label in label_sequence}
    per_label_mis = {label: [] for label in label_sequence}
    for p in predictions:
        label = p["gold_label"]
        if label in per_label_all:
            per_label_all[label].append(p)
            if not p["exact_correct"]:
                per_label_mis[label].append(p)

    for label in label_sequence:
        per_label_all[label].sort(key=_rank_key)
        per_label_mis[label].sort(key=_rank_key)

    # Balanced quota (approximately equal removals per label).
    num_labels = len(label_sequence)
    base_quota = remove_count // num_labels
    remainder = remove_count % num_labels
    quotas = {label: base_quota for label in label_sequence}

    # Assign remainder to labels with more misclassified candidates first.
    remainder_order = sorted(
        label_sequence,
        key=lambda label: (-len(per_label_mis[label]), LABEL_ORDER[label]),
    )
    for label in remainder_order[:remainder]:
        quotas[label] += 1

    removed = []
    removed_indices_set = set()

    # First pass: satisfy each label quota with misclassified items.
    for label in label_sequence:
        take = min(quotas[label], len(per_label_mis[label]))
        for item in per_label_mis[label][:take]:
            removed.append(item)
            removed_indices_set.add(item["index"])

    # Second pass: if some quotas could not be met, fill within those labels
    # using next-worst remaining items (can include correctly classified).
    for label in label_sequence:
        needed = quotas[label] - sum(1 for x in removed if x["gold_label"] == label)
        if needed <= 0:
            continue
        candidates = [
            x for x in per_label_all[label] if x["index"] not in removed_indices_set
        ]
        for item in candidates[:needed]:
            removed.append(item)
            removed_indices_set.add(item["index"])

    # Final pass: if still short (edge cases), fill globally by worst rank.
    if len(removed) < remove_count:
        remaining_global = sorted(
            (p for p in predictions if p["index"] not in removed_indices_set),
            key=_rank_key,
        )
        need = remove_count - len(removed)
        for item in remaining_global[:need]:
            removed.append(item)
            removed_indices_set.add(item["index"])

    # Keep deterministic order in output by rank.
    removed = sorted(removed, key=_rank_key)[:remove_count]
    removed_indices = sorted(p["index"] for p in removed)
    return removed, removed_indices


def run_inference(
    model_path: str,
    dataset_path: str,
    output_path: str,
    predictions_path: str,
    clean_dataset_path: str,
    removed_path: str,
    target_clean_size: int,
) -> None:
    """Classify the whole dataset, write report/predictions/clean/removed files.

    Raises:
        ValueError: if target_clean_size is not strictly between 0 and the
            number of valid examples.
    """
    openai_api_key = load_openai_key(API_FILE)
    student_lm = dspy.LM(model="gpt-5-mini", api_key=openai_api_key)
    dspy.configure(lm=student_lm)

    classifier = HealthLiteracyClassifier()
    classifier.load(model_path)

    examples = load_full_examples(dataset_path)
    total = len(examples)
    if target_clean_size <= 0 or target_clean_size >= total:
        raise ValueError(
            f"target_clean_size must be between 1 and {total - 1}, got {target_clean_size}"
        )
    remove_count = total - target_clean_size

    correct = 0
    label_totals = Counter()
    label_correct = Counter()
    predictions = []

    for idx, ex in enumerate(
        tqdm(examples, desc="Classifying full dataset", unit="sample"), start=1
    ):
        pred = classifier(generated_text=ex["generated_text"])
        raw_pred_label = getattr(pred, "literacy_label", "")
        pred_label = extract_predicted_label(raw_pred_label)
        gold_label = ex["gold_label"]
        exact_correct = pred_label == gold_label
        lenient_correct = is_correct(gold_label, raw_pred_label)
        # Correct predictions have severity 0 by definition.
        severity = (
            misclassification_severity(gold_label, pred_label)
            if not exact_correct
            else 0
        )

        label_totals[gold_label] += 1
        if lenient_correct:
            correct += 1
            label_correct[gold_label] += 1

        predictions.append(
            {
                "index": ex["index"],
                "doc_id": ex["doc_id"],
                "gold_label": gold_label,
                "predicted_label": pred_label,
                "raw_prediction_text": raw_pred_label,
                "lenient_correct": lenient_correct,
                "exact_correct": exact_correct,
                "severity": severity,
                "generated_text": ex["generated_text"],
            }
        )
        if idx % 10 == 0 or idx == total:
            tqdm.write(f"Processed {idx}/{total}")

    accuracy = correct / total if total else 0.0
    exact_accuracy = (
        sum(1 for p in predictions if p["exact_correct"]) / total if total else 0.0
    )
    per_label_accuracy = {
        label: (
            (label_correct[label] / label_totals[label])
            if label_totals[label]
            else 0.0
        )
        for label in sorted(VALID_LABELS)
    }

    removed_examples, removed_indices = choose_indices_to_remove(
        predictions, remove_count
    )
    removed_index_set = set(removed_indices)
    # NOTE: iterate over `examples` (not `predictions`) so the original raw
    # items are preserved verbatim in the output datasets.
    clean_dataset = [
        ex["raw_item"] for ex in examples if ex["index"] not in removed_index_set
    ]
    removed_dataset = [
        ex["raw_item"] for ex in examples if ex["index"] in removed_index_set
    ]

    report = {
        "model_path": model_path,
        "dataset_path": dataset_path,
        "num_examples": total,
        "num_correct": correct,
        "lenient_accuracy": accuracy,
        "exact_accuracy": exact_accuracy,
        "per_label_accuracy": per_label_accuracy,
        "target_clean_size": target_clean_size,
        "removed_count": remove_count,
        "clean_dataset_size": len(clean_dataset),
        "removed_dataset_size": len(removed_dataset),
        "removed_misclassified_count": sum(
            1 for p in removed_examples if not p["exact_correct"]
        ),
        "removed_per_label": dict(
            Counter(p["gold_label"] for p in removed_examples)
        ),
    }

    for path in [
        output_path,
        predictions_path,
        clean_dataset_path,
        removed_path,
    ]:
        output_dir = os.path.dirname(path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    # All outputs use UTF-8 with ensure_ascii=False so non-ASCII text in
    # generated_text round-trips identically across the four files.
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2, ensure_ascii=False)
    with open(predictions_path, "w", encoding="utf-8") as f:
        json.dump(predictions, f, indent=2, ensure_ascii=False)
    with open(clean_dataset_path, "w", encoding="utf-8") as f:
        json.dump(clean_dataset, f, indent=2, ensure_ascii=False)
    with open(removed_path, "w", encoding="utf-8") as f:
        json.dump(removed_dataset, f, indent=2, ensure_ascii=False)

    print(json.dumps(report, indent=2))
    print(f"Saved predictions to: {predictions_path}")
    print(f"Saved clean dataset to: {clean_dataset_path}")
    print(f"Saved removed examples to: {removed_path}")
    print(f"Saved report to: {output_path}")


def main() -> None:
    """CLI entry point: parse arguments and run the full-dataset evaluation."""
    parser = argparse.ArgumentParser(
        description="Load a compiled DSPy classifier and evaluate on full dataset."
    )
    parser.add_argument("--model-path", default=DEFAULT_MODEL_PATH)
    parser.add_argument("--dataset-path", default=DEFAULT_DATASET_PATH)
    parser.add_argument("--output-path", default=DEFAULT_OUTPUT_PATH)
    parser.add_argument("--predictions-path", default=DEFAULT_PREDICTIONS_PATH)
    parser.add_argument("--clean-dataset-path", default=DEFAULT_CLEAN_DATASET_PATH)
    parser.add_argument("--removed-path", default=DEFAULT_REMOVED_PATH)
    parser.add_argument("--target-clean-size", type=int, default=200)
    args = parser.parse_args()

    run_inference(
        model_path=args.model_path,
        dataset_path=args.dataset_path,
        output_path=args.output_path,
        predictions_path=args.predictions_path,
        clean_dataset_path=args.clean_dataset_path,
        removed_path=args.removed_path,
        target_clean_size=args.target_clean_size,
    )


if __name__ == "__main__":
    main()