"""Evaluate a saved DSPy health-literacy classifier on a held-out test set,
using a vLLM OpenAI-compatible endpoint as the language model backend."""

import argparse
import json
import os
import traceback
import urllib.error
import urllib.request

import dspy
from dspy.evaluate import Evaluate

DEFAULT_API_BASE = "http://172.16.34.22:8040/v1"
DEFAULT_MODEL_PATH = (
    "/home/mshahidul/readctrl/code/text_classifier/dspy_model/vllm-Meta-Llama-3.1-8B-Instruct_teacher-gpt5_v1/model.json"
)
DEFAULT_TEST_PATH = "/home/mshahidul/readctrl/code/text_classifier/data/verified_combined_0-80_clean200.json"
DEFAULT_OUTPUT_PATH = (
    "/home/mshahidul/readctrl/code/text_classifier/accuracy/"
    "vllm-llama-3.1-8b-awq-int4_teacher-gpt5_v1_clean200_eval.json"
)


class HealthLiteracySignature(dspy.Signature):
    """Classify the health-literacy level of a rewritten text."""

    generated_text = dspy.InputField(
        desc="A version of the source text rewritten for a specific audience."
    )
    literacy_label = dspy.OutputField(
        desc=(
            "Classification: low_health_literacy (simple words, no jargon), "
            "intermediate_health_literacy (moderate technicality), or "
            "proficient_health_literacy (highly technical/original level)."
        )
    )


class HealthLiteracyClassifier(dspy.Module):
    """Chain-of-thought classifier over HealthLiteracySignature."""

    def __init__(self):
        super().__init__()
        self.classifier = dspy.ChainOfThought(HealthLiteracySignature)

    def forward(self, generated_text):
        return self.classifier(generated_text=generated_text)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Load a saved DSPy model and evaluate it on the test set."
    )
    parser.add_argument("--model-path", default=DEFAULT_MODEL_PATH)
    parser.add_argument("--test-path", default=DEFAULT_TEST_PATH)
    parser.add_argument(
        "--api-base",
        default=os.environ.get("VLLM_API_BASE", DEFAULT_API_BASE),
    )
    parser.add_argument("--num-threads", type=int, default=1)
    parser.add_argument("--output-path", default=DEFAULT_OUTPUT_PATH)
    parser.add_argument(
        "--provide-traceback",
        action="store_true",
        help="Print the full traceback if a runtime error occurs.",
    )
    return parser.parse_args()


def check_api_base(api_base):
    """Fail fast if the OpenAI-compatible endpoint is unreachable or unhealthy."""
    models_url = api_base.rstrip("/") + "/models"
    req = urllib.request.Request(models_url, method="GET")
    try:
        with urllib.request.urlopen(req, timeout=5) as resp:
            if resp.status >= 400:
                raise RuntimeError(
                    f"Endpoint reachable but unhealthy: {models_url} (status={resp.status})"
                )
    except urllib.error.URLError as exc:
        raise ConnectionError(
            "Cannot reach OpenAI-compatible endpoint. "
            f"api_base={api_base}. "
            "Start your vLLM server or pass the correct --api-base."
        ) from exc


def load_testset(path):
    """Load test examples from a .jsonl or .json file into dspy.Example objects."""
    examples = []
    if path.endswith(".jsonl"):
        with open(path, "r") as f:
            for line in f:
                if not line.strip():
                    continue
                record = json.loads(line)
                example = dspy.Example(
                    generated_text=record["generated_text"],
                    literacy_label=record["literacy_label"],
                ).with_inputs("generated_text")
                examples.append(example)
    else:
        with open(path, "r") as f:
            records = json.load(f)
        for record in records:
            # Fall back to alternate field names used by older data exports.
            text = record.get("generated_text", record.get("diff_label_texts"))
            label = record.get("literacy_label", record.get("label"))
            if not text or not label:
                continue
            example = dspy.Example(
                generated_text=text,
                literacy_label=label,
            ).with_inputs("generated_text")
            examples.append(example)
    return examples


def health_literacy_metric(gold, pred, trace=None):
    """Count a prediction correct if the gold label appears in the predicted label."""
    if not pred or not hasattr(pred, "literacy_label"):
        return False
    gold_label = str(gold.literacy_label).strip().lower()
    pred_label = str(pred.literacy_label).strip().lower()
    return gold_label in pred_label


def load_compiled_classifier(path):
    """Load the compiled program via dspy.load, falling back to module.load."""
    if hasattr(dspy, "load"):
        try:
            return dspy.load(path)
        except Exception as exc:
            print(
                f"[warning] dspy.load failed ({type(exc).__name__}); "
                "trying module.load(...)"
            )
    classifier = HealthLiteracyClassifier()
    try:
        classifier.load(path)
    except Exception as exc:
        raise RuntimeError(f"Failed to load compiled model from {path}") from exc
    return classifier


def main():
    args = parse_args()

    if not os.path.exists(args.model_path):
        raise FileNotFoundError(f"Model file not found: {args.model_path}")
    if not os.path.exists(args.test_path):
        raise FileNotFoundError(f"Test file not found: {args.test_path}")

    try:
        check_api_base(args.api_base)

        lm = dspy.LM(
            model="openai/dspy",
            api_base=args.api_base,
            api_key="EMPTY",
            temperature=0.0,
        )
        dspy.configure(lm=lm)

        testset = load_testset(args.test_path)
        compiled_classifier = load_compiled_classifier(args.model_path)

        evaluator = Evaluate(
            devset=testset,
            metric=health_literacy_metric,
            num_threads=args.num_threads,
            display_progress=True,
        )
        evaluation_result = evaluator(compiled_classifier)

        # Newer DSPy versions return an object with a .score attribute;
        # older versions return the score directly.
        accuracy_score = (
            float(evaluation_result.score)
            if hasattr(evaluation_result, "score")
            else float(evaluation_result)
        )

        output_data = {
            "model_path": args.model_path,
            "test_path": args.test_path,
            "accuracy_score": accuracy_score,
            "num_results": len(getattr(evaluation_result, "results", []) or []),
        }

        os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
        with open(args.output_path, "w") as f:
            json.dump(output_data, f, indent=2)

        print(evaluation_result)
        print(json.dumps(output_data, indent=2))
    except Exception as exc:
        print(f"[error] {type(exc).__name__}: {exc}")
        if args.provide_traceback:
            traceback.print_exc()
        raise


if __name__ == "__main__":
    main()