| import argparse | |
| import json | |
| import os | |
| import traceback | |
| import urllib.error | |
| import urllib.request | |
| import dspy | |
| from dspy.evaluate import Evaluate | |
# OpenAI-compatible vLLM endpoint used when the VLLM_API_BASE env var is unset.
DEFAULT_API_BASE = "http://172.16.34.22:8040/v1"
# Saved DSPy program (compiled classifier state) evaluated by default.
DEFAULT_MODEL_PATH = (
    "/home/mshahidul/readctrl/code/text_classifier/dspy_model/vllm-Meta-Llama-3.1-8B-Instruct_teacher-gpt5_v1/model.json"
)
# Labeled test set (JSON list of records) used as the evaluation devset.
DEFAULT_TEST_PATH = "/home/mshahidul/readctrl/code/text_classifier/data/verified_combined_0-80_clean200.json"
# Destination for the JSON accuracy summary written by main().
DEFAULT_OUTPUT_PATH = (
    "/home/mshahidul/readctrl/code/text_classifier/accuracy/"
    "vllm-llama-3.1-8b-awq-int4_teacher-gpt5_v1_clean200_eval.json"
)
# Signature for classifying a rewritten health text into one of three
# literacy levels. NOTE(review): intentionally left without a class
# docstring — dspy uses a Signature docstring as the task-instruction text
# in the prompt, so adding one here would change model behavior.
class HealthLiteracySignature(dspy.Signature):
    # Input: the rewritten text to be classified.
    generated_text = dspy.InputField(
        desc="A version of the source text rewritten for a specific audience."
    )
    # Output: one of the three literacy-level label strings described below.
    literacy_label = dspy.OutputField(
        desc=(
            "Classification: low_health_literacy (simple words, no jargon), "
            "intermediate_health_literacy (moderate technicality), or "
            "proficient_health_literacy (highly technical/original level)."
        )
    )
class HealthLiteracyClassifier(dspy.Module):
    """Chain-of-thought wrapper around HealthLiteracySignature."""

    def __init__(self):
        super().__init__()
        # Single predictor; ChainOfThought adds a reasoning step before the label.
        self.classifier = dspy.ChainOfThought(HealthLiteracySignature)

    def forward(self, generated_text):
        """Classify one rewritten text and return the dspy prediction."""
        prediction = self.classifier(generated_text=generated_text)
        return prediction
def parse_args():
    """Build and parse the command-line options for the evaluation run."""
    parser = argparse.ArgumentParser(
        description="Load a saved DSPy model and evaluate on test set."
    )
    add = parser.add_argument
    add("--model-path", default=DEFAULT_MODEL_PATH)
    add("--test-path", default=DEFAULT_TEST_PATH)
    # The VLLM_API_BASE environment variable overrides the baked-in endpoint.
    add("--api-base", default=os.environ.get("VLLM_API_BASE", DEFAULT_API_BASE))
    add("--num-threads", type=int, default=1)
    add("--output-path", default=DEFAULT_OUTPUT_PATH)
    add(
        "--provide-traceback",
        action="store_true",
        help="Print full traceback if runtime error happens.",
    )
    return parser.parse_args()
def check_api_base(api_base):
    """Fail fast if the OpenAI-compatible endpoint is not up and healthy.

    Sends GET <api_base>/models.

    Raises:
        RuntimeError: the server is reachable but answers with an HTTP error.
        ConnectionError: the server cannot be reached at all.
    """
    models_url = api_base.rstrip("/") + "/models"
    req = urllib.request.Request(models_url, method="GET")
    try:
        with urllib.request.urlopen(req, timeout=5) as resp:
            # Defensive only: urlopen normally raises HTTPError for >=400.
            if resp.status >= 400:
                raise RuntimeError(
                    f"Endpoint reachable but unhealthy: {models_url} (status={resp.status})"
                )
    except urllib.error.HTTPError as exc:
        # HTTPError is a URLError subclass — it must be handled first, or a
        # reachable-but-failing server is misreported as unreachable.
        raise RuntimeError(
            f"Endpoint reachable but unhealthy: {models_url} (status={exc.code})"
        ) from exc
    except urllib.error.URLError as exc:
        raise ConnectionError(
            "Cannot reach OpenAI-compatible endpoint. "
            f"api_base={api_base}. "
            "Start your vLLM server or pass correct --api-base."
        ) from exc
def load_testset(path):
    """Load evaluation examples from a JSON or JSONL file.

    Each usable record becomes a dspy.Example with input ``generated_text``
    and gold ``literacy_label``. Records missing either field are skipped in
    BOTH formats (previously the JSONL branch raised KeyError and lacked the
    fallback field names used by the JSON branch).

    Returns:
        list[dspy.Example]: examples with ``generated_text`` marked as input.
    """
    if path.endswith(".jsonl"):
        # One JSON object per non-blank line.
        with open(path, "r") as f:
            records = [json.loads(line) for line in f if line.strip()]
    else:
        with open(path, "r") as f:
            records = json.load(f)
    examples = []
    for record in records:
        # Fall back to the alternate field names used by older exports.
        text = record.get("generated_text", record.get("diff_label_texts"))
        label = record.get("literacy_label", record.get("label"))
        if not text or not label:
            continue  # skip incomplete records instead of crashing
        example = dspy.Example(
            generated_text=text,
            literacy_label=label,
        ).with_inputs("generated_text")
        examples.append(example)
    return examples
def health_literacy_metric(gold, pred, trace=None):
    """Return True when the gold label occurs within the predicted label.

    The comparison is case-insensitive and ignores surrounding whitespace,
    so a verbose prediction still scores as long as it contains the gold
    label string. A missing/falsy prediction scores False.
    """
    def canon(value):
        return str(value).strip().lower()

    if not pred or not hasattr(pred, "literacy_label"):
        return False
    return canon(gold.literacy_label) in canon(pred.literacy_label)
def load_compiled_classifier(path):
    """Load a compiled DSPy program from disk.

    Tries ``dspy.load`` first when the installed dspy provides it; if that
    fails (or is unavailable), loads the saved state into a fresh
    HealthLiteracyClassifier via ``Module.load``.

    Raises:
        RuntimeError: when both loading strategies fail.
    """
    loader = getattr(dspy, "load", None)
    if loader is not None:
        try:
            return loader(path)
        except Exception as exc:
            print(
                f"[warning] dspy.load failed ({type(exc).__name__}); "
                "trying module.load(...)"
            )
    fallback = HealthLiteracyClassifier()
    try:
        fallback.load(path)
    except Exception as exc:
        raise RuntimeError(f"Failed to load compiled model from {path}") from exc
    return fallback
def main():
    """Entry point: health-check the endpoint, run the evaluation, save results.

    Raises FileNotFoundError for missing model/test files; any error inside
    the evaluation is logged (with optional traceback) and re-raised.
    """
    args = parse_args()
    if not os.path.exists(args.model_path):
        raise FileNotFoundError(f"Model file not found: {args.model_path}")
    if not os.path.exists(args.test_path):
        raise FileNotFoundError(f"Test file not found: {args.test_path}")
    try:
        check_api_base(args.api_base)
        # Point the OpenAI-compatible client at the vLLM server; temperature 0
        # keeps the evaluation deterministic.
        lm = dspy.LM(
            model="openai/dspy",
            api_base=args.api_base,
            api_key="EMPTY",
            temperature=0.0,
        )
        dspy.configure(lm=lm)
        testset = load_testset(args.test_path)
        compiled_classifier = load_compiled_classifier(args.model_path)
        evaluator = Evaluate(
            devset=testset,
            metric=health_literacy_metric,
            num_threads=args.num_threads,
            display_progress=True,
        )
        evaluation_result = evaluator(compiled_classifier)
        # Newer dspy returns an object with .score; older versions return the
        # numeric score directly.
        accuracy_score = (
            float(evaluation_result.score)
            if hasattr(evaluation_result, "score")
            else float(evaluation_result)
        )
        output_data = {
            "model_path": args.model_path,
            "test_path": args.test_path,
            "accuracy_score": accuracy_score,
            "num_results": len(getattr(evaluation_result, "results", []) or []),
        }
        # dirname is "" for a bare filename, and makedirs("") raises — only
        # create the directory when there is one.
        output_dir = os.path.dirname(args.output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(args.output_path, "w") as f:
            json.dump(output_data, f, indent=2)
        print(evaluation_result)
        print(json.dumps(output_data, indent=2))
    except Exception as exc:
        print(f"[error] {type(exc).__name__}: {exc}")
        if args.provide_traceback:
            traceback.print_exc()
        raise
# Run the evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()