import argparse
import glob
import json
import os
import traceback
import urllib.error
import urllib.request
from datetime import datetime
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import dspy
from tqdm import tqdm
|
|
|
|
# Default locations for an evaluation run. Each one is only a default: the
# command-line options override them, and the API base can also come from the
# VLLM_API_BASE environment variable.
_READCTRL_CODE_DIR = "/home/mshahidul/readctrl/code"

DEFAULT_API_BASE = "http://172.16.34.21:8040/v1"
DEFAULT_MODEL_PATH = (
    f"{_READCTRL_CODE_DIR}/text_classifier/dspy_model/"
    "vllm-Meta-Llama-3.1-8B-Instruct_teacher-gpt5_v1/model.json"
)
DEFAULT_INPUT_PATH = f"{_READCTRL_CODE_DIR}/RL_model/inference_data"
DEFAULT_INPUT_FILE = (
    f"{DEFAULT_INPUT_PATH}/"
    "vllm_inference_qwen-qwen3-4b-instruct-2507_20260213_173334.jsonl"
)
DEFAULT_OUTPUT_DIR = f"{_READCTRL_CODE_DIR}/rl_inference/test_result"

# The gold labels an input row may carry; rows labelled otherwise are not
# evaluated.
VALID_LABELS = {
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
}
|
|
|
|
# DSPy signature for the health-literacy classifier: one input field
# (generated_text) and one output field (literacy_label).
# NOTE(review): DSPy uses a Signature class's docstring as the task
# instructions in the prompt, so this class is documented with comments
# only -- adding a docstring would change the prompt. The field names and
# desc strings are presumably what the compiled model.json was optimized
# with; confirm before editing them.
class HealthLiteracySignature(dspy.Signature):
    # Input: the text whose reading level is being judged.
    generated_text = dspy.InputField(
        desc="A version of the source text rewritten for a specific audience."
    )
    # Output: the predicted label. The three labels named in the desc are the
    # same strings as the gold labels this script scores against.
    literacy_label = dspy.OutputField(
        desc=(
            "Classification: low_health_literacy (simple words, no jargon), "
            "intermediate_health_literacy (moderate technicality), or "
            "proficient_health_literacy (highly technical/original level)."
        )
    )
|
|
|
|
class HealthLiteracyClassifier(dspy.Module):
    """Chain-of-thought predictor over ``HealthLiteracySignature``.

    NOTE(review): the predictor lives in the attribute ``classifier``; saved
    program state is presumably matched to predictors by attribute name when
    ``load`` is called, so keep this name stable.
    """

    def __init__(self):
        signature = HealthLiteracySignature
        super().__init__()
        self.classifier = dspy.ChainOfThought(signature)

    def forward(self, generated_text):
        """Run the predictor on ``generated_text`` and return its prediction."""
        prediction = self.classifier(generated_text=generated_text)
        return prediction
|
|
|
|
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options for an evaluation run.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses the process
            command line, as before; passing a list lets tests or other
            scripts drive the parser directly.

    Returns:
        The parsed options namespace.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate saved DSPy classifier on saved vLLM inference outputs."
    )
    parser.add_argument(
        "--model-path",
        default=DEFAULT_MODEL_PATH,
        help="Path to the compiled DSPy classifier (model.json).",
    )
    parser.add_argument(
        "--input-path",
        default=DEFAULT_INPUT_FILE,
        help=(
            "Path to vLLM output JSONL (e.g. vllm_inference_*.jsonl). "
            "Set to empty string to auto-select latest file in --search-dir."
        ),
    )
    parser.add_argument(
        "--search-dir",
        default=DEFAULT_INPUT_PATH,
        help="Directory to auto-search for vllm_inference_*.jsonl",
    )
    parser.add_argument(
        "--api-base",
        default=os.environ.get("VLLM_API_BASE", DEFAULT_API_BASE),
        help="OpenAI-compatible base URL of the vLLM server (env: VLLM_API_BASE).",
    )
    parser.add_argument(
        "--output-dir",
        default=DEFAULT_OUTPUT_DIR,
        help="Directory for the summary (.json) and per-row details (.jsonl).",
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=-1,
        help="Maximum number of rows to evaluate; any value <= 0 (default -1) uses all rows.",
    )
    parser.add_argument(
        "--provide-traceback",
        action="store_true",
        help="Print full traceback if runtime error happens.",
    )
    return parser.parse_args(argv)
|
|
|
|
def check_api_base(api_base: str, timeout: float = 5.0) -> None:
    """Fail fast unless the server answers ``GET {api_base}/models``.

    Args:
        api_base: Base URL such as ``http://host:port/v1``; a trailing slash
            is ignored.
        timeout: Seconds to wait for the connection and the response.

    Raises:
        RuntimeError: the server answered, but with an HTTP error status.
        ConnectionError: the server could not be reached (refused, DNS
            failure, timeout, connection reset).
    """
    models_url = api_base.rstrip("/") + "/models"
    req = urllib.request.Request(models_url, method="GET")
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            status = resp.status
    except urllib.error.HTTPError as exc:
        # urlopen raises HTTPError (a URLError subclass) for 4xx/5xx answers;
        # it must be handled before the generic branch, otherwise a server that
        # answers but is unhealthy would be reported as unreachable.
        raise RuntimeError(
            f"Endpoint reachable but unhealthy: {models_url} (status={exc.code})"
        ) from exc
    except OSError as exc:
        # URLError is an OSError; a timeout or reset while reading the response
        # can also surface as a bare OSError subclass not wrapped in URLError.
        raise ConnectionError(
            "Cannot reach OpenAI-compatible endpoint. "
            f"api_base={api_base}. "
            "Start your vLLM server or pass correct --api-base."
        ) from exc
    # Normally unreachable (the default opener raises HTTPError instead); kept
    # as a guard in case a custom global opener returns error responses.
    if status >= 400:
        raise RuntimeError(
            f"Endpoint reachable but unhealthy: {models_url} (status={status})"
        )
|
|
|
|
def resolve_input_path(input_path: str, search_dir: str) -> str:
    """Return the inference JSONL to evaluate.

    * ``input_path`` is an existing directory: the newest
      ``vllm_inference_*.jsonl`` file inside it is used.
    * ``input_path`` is any other existing path: it is returned unchanged.
    * ``input_path`` is non-empty but missing: FileNotFoundError.
    * ``input_path`` is empty: the newest ``vllm_inference_*.jsonl`` file in
      ``search_dir`` is used.

    "Newest" means largest modification time; equal times are broken by
    path so the choice is deterministic.

    Raises:
        FileNotFoundError: the given path does not exist, or the searched
            directory holds no matching file.
    """
    if input_path and os.path.isdir(input_path):
        # A directory is a search location, not a file to open.
        search_dir = input_path
    elif input_path:
        if not os.path.exists(input_path):
            raise FileNotFoundError(f"Input file not found: {input_path}")
        return input_path

    pattern = os.path.join(search_dir, "vllm_inference_*.jsonl")
    candidates = [path for path in glob.glob(pattern) if os.path.isfile(path)]
    if not candidates:
        raise FileNotFoundError(
            "No vLLM output file found. Expected pattern: "
            f"{search_dir}/vllm_inference_*.jsonl"
        )
    return max(candidates, key=lambda path: (os.path.getmtime(path), path))
|
|
|
|
def load_compiled_classifier(path: str):
    """Load the compiled classifier saved at ``path``.

    First tries ``dspy.load`` (when this DSPy version provides it). If that
    fails, a fresh ``HealthLiteracyClassifier`` is built and ``load(path)`` is
    called on it. The first failure is deliberately tolerated, but it is kept
    and reported if the fallback fails too, so the real cause is not hidden.

    NOTE(review): ``dspy.load`` restores a whole saved program and, per the
    DSPy docs, presumably unpickles it -- only point this at trusted files.

    Raises:
        RuntimeError: both loading strategies failed.
    """
    whole_program_error = None
    if hasattr(dspy, "load"):
        try:
            return dspy.load(path)
        except Exception as exc:  # best effort; fall back to state-only load
            whole_program_error = exc

    classifier = HealthLiteracyClassifier()
    try:
        classifier.load(path)
    except Exception as exc:
        message = f"Failed to load compiled model from {path}"
        if whole_program_error is not None:
            message += (
                " (dspy.load also failed: "
                f"{type(whole_program_error).__name__}: {whole_program_error})"
            )
        raise RuntimeError(message) from exc
    return classifier
|
|
|
|
def normalize_pred_label(pred_obj: Any) -> str:
    """Return ``pred_obj.literacy_label`` as stripped, lower-cased text.

    An empty string is returned when ``pred_obj`` is falsy or has no
    ``literacy_label`` attribute.
    """
    if pred_obj and hasattr(pred_obj, "literacy_label"):
        label_text = str(pred_obj.literacy_label)
        return label_text.strip().lower()
    return ""
|
|
|
|
def load_eval_items(
    path: str, valid_labels: Optional[Iterable[str]] = None
) -> List[Dict[str, Any]]:
    """Load the rows of a vLLM inference JSONL file that can be evaluated.

    A row is kept only when its ``gold_label`` is an accepted label, its
    ``error`` field is empty or absent, and it has non-empty text
    (``generated_text``, falling back to ``prediction``). JSON ``null`` is
    treated as a missing value, never as the text ``"None"``.

    Args:
        path: JSONL file with one JSON object per line; blank lines are ignored.
        valid_labels: Gold labels to accept; ``None`` uses ``VALID_LABELS``.

    Returns:
        One dict per kept row with keys ``line_no`` (1-based line in the
        file), ``row_index``, ``doc_id``, ``gold_label`` and ``generated_text``.

    Raises:
        json.JSONDecodeError: a line is not valid JSON (message names the line).
    """
    allowed = set(VALID_LABELS if valid_labels is None else valid_labels)

    def _text(value: Any) -> str:
        # str(None) == "None": a null "error" would look like a real error
        # (dropping the row) and a null "generated_text" like real text.
        return "" if value is None else str(value).strip()

    items: List[Dict[str, Any]] = []
    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                row = json.loads(line)
            except json.JSONDecodeError as exc:
                raise json.JSONDecodeError(
                    f"Invalid JSON in {path} at line {line_no}: {exc.msg}",
                    exc.doc,
                    exc.pos,
                ) from exc

            gold_label = _text(row.get("gold_label"))
            generated_text = _text(row.get("generated_text")) or _text(
                row.get("prediction")
            )
            if gold_label not in allowed:
                continue
            if _text(row.get("error")):
                continue
            if not generated_text:
                continue

            items.append(
                {
                    "line_no": line_no,
                    "row_index": row.get("row_index"),
                    "doc_id": row.get("doc_id"),
                    "gold_label": gold_label,
                    "generated_text": generated_text,
                }
            )
    return items
|
|
|
|
def _extract_label(pred_label: str) -> str:
    """Return the single valid label contained in ``pred_label``, else ``""``.

    The LM's raw output may carry extra text around the label, so labels are
    located by substring. None of the labels is a substring of another. When
    no label, or more than one distinct label, appears, the prediction is
    ambiguous and ``""`` is returned so it can never be scored as correct.
    """
    found = [label for label in VALID_LABELS if label in pred_label]
    return found[0] if len(found) == 1 else ""


def _classify_items(
    classifier: Any, items: List[Dict[str, Any]]
) -> Tuple[List[Dict[str, Any]], int, int]:
    """Classify every item; return (per-row records, #correct, #failed calls)."""
    results: List[Dict[str, Any]] = []
    correct = 0
    failed = 0
    for item in tqdm(items, desc="Classifying"):
        error = ""
        try:
            pred = classifier(generated_text=item["generated_text"])
        except Exception as exc:
            # One failed LM call (network hiccup, unparsable output) should not
            # throw away every prediction made so far: record it and score the
            # row as incorrect.
            pred = None
            error = f"{type(exc).__name__}: {exc}"
            failed += 1
            tqdm.write(f"[warn] line {item['line_no']}: {error}")

        pred_label = normalize_pred_label(pred)
        matched_label = _extract_label(pred_label)
        is_correct = bool(matched_label) and matched_label == item["gold_label"]
        correct += int(is_correct)

        record: Dict[str, Any] = {
            "line_no": item["line_no"],
            "row_index": item["row_index"],
            "doc_id": item.get("doc_id"),
            "gold_label": item["gold_label"],
            "pred_label": pred_label,
            "matched_label": matched_label,
            "is_correct": is_correct,
        }
        if error:
            record["error"] = error
        results.append(record)
    return results, correct, failed


def _write_outputs(
    args: argparse.Namespace,
    results: List[Dict[str, Any]],
    correct: int,
    failed: int,
) -> Tuple[str, Dict[str, Any]]:
    """Write the summary JSON and per-row JSONL; return (summary path, summary)."""
    total = len(results)
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs(args.output_dir, exist_ok=True)
    summary_path = os.path.join(args.output_dir, f"classifier_eval_vllm_{ts}.json")
    details_path = os.path.join(args.output_dir, f"classifier_eval_vllm_{ts}.jsonl")

    summary: Dict[str, Any] = {
        "model_path": args.model_path,
        "input_path": args.input_path,
        "api_base": args.api_base,
        "total_samples": total,
        "correct_samples": correct,
        "failed_samples": failed,
        "accuracy_score": correct / total if total else 0.0,
        "details_path": details_path,
    }
    with open(details_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(r, ensure_ascii=False) + "\n" for r in results)
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)
    return summary_path, summary


def main() -> None:
    """Evaluate the compiled classifier on a vLLM inference JSONL and save results.

    A row is scored correct only when exactly one valid label appears in the
    classifier's output and it equals the gold label. Per-row classifier
    failures are recorded in the details file and scored as incorrect; if
    every call fails, the run raises after writing its outputs. Any error is
    printed (with a traceback when ``--provide-traceback`` is set) and
    re-raised.
    """
    args = parse_args()
    try:
        args.input_path = resolve_input_path(args.input_path, args.search_dir)
        if not os.path.exists(args.model_path):
            raise FileNotFoundError(f"Model file not found: {args.model_path}")

        check_api_base(args.api_base)
        lm = dspy.LM(
            model="openai/dspy",
            api_base=args.api_base,
            api_key="EMPTY",
            temperature=0.0,
        )
        dspy.configure(lm=lm)
        classifier = load_compiled_classifier(args.model_path)

        print(f"[INFO] Using input file: {args.input_path}")
        parsed_items = load_eval_items(args.input_path)
        if args.max_samples > 0:
            parsed_items = parsed_items[: args.max_samples]
        if not parsed_items:
            raise RuntimeError("No valid rows found in input file for classifier evaluation.")

        results, correct, failed = _classify_items(classifier, parsed_items)
        summary_path, summary = _write_outputs(args, results, correct, failed)

        print(
            json.dumps(
                {
                    "total_samples": summary["total_samples"],
                    "accuracy_score": summary["accuracy_score"],
                    "failed_samples": failed,
                },
                indent=2,
            )
        )
        print(f"[DONE] Summary saved: {summary_path}")
        print(f"[DONE] Details saved: {summary['details_path']}")

        # A systematic problem (wrong model name, dead server) makes every
        # call fail; surface it loudly instead of reporting accuracy 0.
        if failed == len(results):
            raise RuntimeError(
                f"All {failed} classifier calls failed; first error: "
                f"{results[0].get('error', '')}"
            )
    except Exception as exc:
        print(f"[error] {type(exc).__name__}: {exc}")
        if args.provide_traceback:
            traceback.print_exc()
        raise
|
|
|
|
# Run the evaluation only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
|
|