import argparse
import glob
import json
import os
import traceback
import urllib.error
import urllib.request
from collections import defaultdict
from datetime import datetime
from typing import Any, Collection, DefaultDict, Dict, List, Optional
| |
|
| | import dspy |
| | from tqdm import tqdm |
| |
|
| |
|
# Deployment-specific defaults; each one can be overridden on the command line
# (--api-base / --model-path / --output-dir).
DEFAULT_API_BASE = "http://172.16.34.22:8040/v1"
DEFAULT_MODEL_PATH = (
    "/home/mshahidul/readctrl/code/text_classifier/dspy_model/"
    "vllm-Meta-Llama-3.1-8B-Instruct_teacher-gpt5_v1/model.json"
)
DEFAULT_OUTPUT_DIR = "/home/mshahidul/readctrl/code/rl_inference/test_result"

# Gold labels that are eligible for evaluation; rows whose gold_label is not
# in this set are skipped when the input file is loaded.
VALID_LABELS = {
    "intermediate_health_literacy",
    "low_health_literacy",
    "proficient_health_literacy",
}
| |
|
| |
|
class HealthLiteracySignature(dspy.Signature):
    # DSPy signature with one input field (the rewritten text) and one output
    # field (the predicted health-literacy label).
    # NOTE(review): DSPy generally derives a Signature's task instructions from
    # its class docstring, and field order/descriptions feed the prompt, so this
    # class is documented with comments only -- confirm against the installed
    # DSPy version before adding a docstring.
    generated_text = dspy.InputField(
        desc="A version of the source text rewritten for a specific audience."
    )
    # The three labels named in this description are the same strings as
    # VALID_LABELS; raw model output is lower-cased/stripped by
    # normalize_pred_label before scoring.
    literacy_label = dspy.OutputField(
        desc=(
            "Classification: low_health_literacy (simple words, no jargon), "
            "intermediate_health_literacy (moderate technicality), or "
            "proficient_health_literacy (highly technical/original level)."
        )
    )
| |
|
| |
|
class HealthLiteracyClassifier(dspy.Module):
    """Chain-of-thought DSPy module that predicts a text's health-literacy label.

    The only sub-module lives at ``self.classifier``. Saved program state is
    presumably keyed by that attribute name when restored via ``.load`` --
    keep it stable.
    """

    def __init__(self):
        super().__init__()
        predictor = dspy.ChainOfThought(HealthLiteracySignature)
        self.classifier = predictor

    def forward(self, generated_text):
        """Classify *generated_text* and return the ChainOfThought prediction.

        Callers read the ``literacy_label`` attribute of the returned object.
        """
        call_kwargs = {"generated_text": generated_text}
        return self.classifier(**call_kwargs)
| |
|
| |
|
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command-line options for the evaluation run.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, exactly as the original zero-argument call did;
            passing a list allows programmatic invocation and testing.

    Returns:
        The parsed namespace (model_path, input_path, api_base, output_dir,
        max_samples, provide_traceback).
    """
    parser = argparse.ArgumentParser(
        description="Evaluate GPT output files with saved DSPy health literacy classifier."
    )
    parser.add_argument(
        "--model-path",
        default=DEFAULT_MODEL_PATH,
        help="Saved DSPy program (model.json) to load.",
    )
    parser.add_argument(
        "--input-path",
        default="",
        help=(
            "Path to GPT output JSONL (e.g. gpt5_inference_all_*.jsonl). "
            "If omitted, auto-select the most recently modified matching file "
            "in --output-dir."
        ),
    )
    parser.add_argument(
        "--api-base",
        # Environment variable wins over the hard-coded default; the flag wins over both.
        default=os.environ.get("VLLM_API_BASE", DEFAULT_API_BASE),
        help="OpenAI-compatible endpoint (default from $VLLM_API_BASE if set).",
    )
    parser.add_argument(
        "--output-dir",
        default=DEFAULT_OUTPUT_DIR,
        help="Where results are written; also searched when --input-path is omitted.",
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=-1,
        help="Evaluate at most this many valid rows; -1 (or any value <= 0) means all.",
    )
    parser.add_argument(
        "--provide-traceback",
        action="store_true",
        help="Print full traceback if runtime error happens.",
    )
    return parser.parse_args(argv)
| |
|
| |
|
def resolve_input_path(input_path: str, search_dir: str) -> str:
    """Return the JSONL file to evaluate.

    An explicit *input_path* must name an existing regular file. When it is
    empty, the newest ``gpt5_inference_all_*.jsonl`` in *search_dir* is used,
    falling back to the newest ``gpt5_inference_*_*.jsonl``.

    Raises:
        FileNotFoundError: the explicit path is missing or is not a regular
            file (e.g. a directory), or no candidate matches in *search_dir*.
    """
    if input_path:
        # isfile (not exists): a directory would otherwise be accepted here and
        # only fail later when it is opened for reading.
        if os.path.isfile(input_path):
            return input_path
        raise FileNotFoundError(f"Input file not found: {input_path}")

    # Patterns in priority order: combined "all" runs first, then any run.
    for pattern in ("gpt5_inference_all_*.jsonl", "gpt5_inference_*_*.jsonl"):
        candidates = [
            path
            for path in glob.glob(os.path.join(search_dir, pattern))
            if os.path.isfile(path)
        ]
        if candidates:
            # Newest by mtime; the path breaks ties deterministically
            # (glob order is filesystem-dependent).
            return max(candidates, key=lambda path: (os.path.getmtime(path), path))

    raise FileNotFoundError(
        "No GPT output file found. Expected pattern: "
        f"{search_dir}/gpt5_inference_all_*.jsonl "
        "or gpt5_inference_*_*.jsonl"
    )
| |
|
| |
|
def check_api_base(api_base: str) -> None:
    """Fail fast unless the OpenAI-compatible server at *api_base* answers.

    Sends ``GET {api_base}/models`` with a 5-second timeout.

    Raises:
        RuntimeError: the server responded, but with an HTTP error status.
        ConnectionError: the server could not be reached (refused, DNS
            failure, timeout, reset).
    """
    models_url = api_base.rstrip("/") + "/models"
    req = urllib.request.Request(models_url, method="GET")
    try:
        with urllib.request.urlopen(req, timeout=5) as resp:
            status = resp.status
    except urllib.error.HTTPError as exc:
        # urlopen raises HTTPError for 4xx/5xx instead of returning a response.
        # HTTPError subclasses URLError, so it must be handled first or an
        # unhealthy-but-reachable server is reported as unreachable.
        raise RuntimeError(
            f"Endpoint reachable but unhealthy: {models_url} (status={exc.code})"
        ) from exc
    except OSError as exc:
        # URLError, socket timeouts (TimeoutError) and connection resets are
        # all OSError subclasses; a bare timeout is not always wrapped in URLError.
        raise ConnectionError(
            "Cannot reach OpenAI-compatible endpoint. "
            f"api_base={api_base}. "
            "Start your vLLM server or pass correct --api-base."
        ) from exc

    # Defensive: a custom opener could return an error response instead of raising.
    if status >= 400:
        raise RuntimeError(
            f"Endpoint reachable but unhealthy: {models_url} (status={status})"
        )
| |
|
| |
|
def load_compiled_classifier(path: str):
    """Load the saved DSPy classifier program from *path*.

    First tries ``dspy.load(path)`` when the installed DSPy exposes it. If that
    is unavailable or fails, it builds a fresh ``HealthLiteracyClassifier`` and
    restores saved state into it with ``.load(path)``.

    Raises:
        RuntimeError: neither strategy succeeded. The message also reports the
            ``dspy.load`` failure, which was previously discarded silently.
    """
    dspy_load_error = None
    if hasattr(dspy, "load"):
        # NOTE(review): in some DSPy versions dspy.load restores a pickled
        # program; only point --model-path at trusted artifacts.
        try:
            return dspy.load(path)
        except Exception as exc:  # best-effort: fall back to state loading below
            dspy_load_error = exc

    classifier = HealthLiteracyClassifier()
    try:
        classifier.load(path)
    except Exception as exc:
        detail = (
            f" (dspy.load also failed: {dspy_load_error!r})"
            if dspy_load_error is not None
            else ""
        )
        raise RuntimeError(
            f"Failed to load compiled model from {path}{detail}"
        ) from exc
    return classifier
| |
|
| |
|
def normalize_pred_label(pred_obj: Any) -> str:
    """Extract the predicted label as a stripped, lower-cased string.

    Returns "" when *pred_obj* is falsy or lacks a ``literacy_label``
    attribute; any other value is passed through ``str`` first.
    """
    has_label = bool(pred_obj) and hasattr(pred_obj, "literacy_label")
    if not has_label:
        return ""
    text = str(pred_obj.literacy_label)
    return text.strip().lower()
| |
|
| |
|
def load_eval_items(
    path: str, valid_labels: Optional[Collection[str]] = None
) -> List[Dict[str, Any]]:
    """Read a GPT-output JSONL file and return the rows that can be scored.

    A row is kept only when all of the following hold:

    - its ``gold_label`` is in *valid_labels* (defaults to ``VALID_LABELS``);
    - its ``error`` field is empty;
    - its ``generated_text`` is non-empty.

    Blank lines and JSON values that are not objects are skipped. Missing
    fields and JSON ``null`` both count as empty, so a successful row written
    as ``"error": null`` is kept rather than mistaken for a failure.

    Returns:
        One dict per kept row with keys line_no, model, row_index, doc_id,
        gold_label and generated_text.

    Raises:
        json.JSONDecodeError: a non-blank line is not valid JSON. The message
            names the file and its 1-based line number.
    """
    labels = VALID_LABELS if valid_labels is None else valid_labels

    def _text(row: Dict[str, Any], key: str) -> str:
        # str(None) would be the truthy string "None"; treat null as empty.
        value = row.get(key)
        return "" if value is None else str(value).strip()

    items: List[Dict[str, Any]] = []
    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                row = json.loads(line)
            except json.JSONDecodeError as exc:
                # Same exception type, but with the file line number; the
                # decoder only sees one line, so its own position says "line 1".
                raise json.JSONDecodeError(
                    f"{exc.msg} [{path}, file line {line_no}]", exc.doc, exc.pos
                ) from exc
            if not isinstance(row, dict):
                continue

            gold_label = _text(row, "gold_label")
            generated_text = _text(row, "generated_text")
            if gold_label not in labels:
                continue
            if _text(row, "error"):
                continue
            if not generated_text:
                continue

            items.append(
                {
                    "line_no": line_no,
                    "model": _text(row, "model") or "unknown_model",
                    "row_index": row.get("row_index"),
                    "doc_id": row.get("doc_id"),
                    "gold_label": gold_label,
                    "generated_text": generated_text,
                }
            )
    return items
| |
|
| |
|
def _match_label(pred_label: str) -> str:
    """Return the single VALID_LABELS entry mentioned in *pred_label*, else "".

    Spaces and hyphens are folded into underscores first, so "low health
    literacy" is recognised as "low_health_literacy". If no label or several
    labels are mentioned, the result is "". This means an answer listing more
    than one label can never be scored correct; a plain substring test would
    credit it for whichever label happened to be gold.
    """
    canonical = "_".join(pred_label.replace("-", " ").split())
    hits = [label for label in VALID_LABELS if label in canonical]
    return hits[0] if len(hits) == 1 else ""


def _per_model_stats(results: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Aggregate total/correct/accuracy per generator model, keys in sorted order."""
    totals: DefaultDict[str, int] = defaultdict(int)
    corrects: DefaultDict[str, int] = defaultdict(int)
    for record in results:
        totals[record["model"]] += 1
        corrects[record["model"]] += int(record["is_correct"])
    return {
        name: {
            "total_samples": totals[name],
            "correct_samples": corrects[name],
            "accuracy_score": (corrects[name] / totals[name]) if totals[name] else 0.0,
        }
        for name in sorted(totals)
    }


def main() -> None:
    """Classify GPT outputs with the saved classifier and write the results.

    Writes two files into ``--output-dir``:

    - ``classifier_eval_gpt5_<timestamp>.json``: the summary, which is also
      printed;
    - ``classifier_eval_gpt5_<timestamp>.jsonl``: one record per evaluated row.

    On any failure it prints ``[error] ...`` (plus a traceback with
    ``--provide-traceback``) and re-raises the exception.
    """
    args = parse_args()

    try:
        # Path validation sits inside the try so these failures get the same
        # "[error]" reporting as every other failure.
        args.input_path = resolve_input_path(args.input_path, args.output_dir)
        if not os.path.exists(args.model_path):
            raise FileNotFoundError(f"Model file not found: {args.model_path}")

        check_api_base(args.api_base)
        lm = dspy.LM(
            model="openai/dspy",
            api_base=args.api_base,
            api_key="EMPTY",
            temperature=0.0,
        )
        dspy.configure(lm=lm)
        classifier = load_compiled_classifier(args.model_path)
        print(f"[INFO] Using input file: {args.input_path}")

        eval_items = load_eval_items(args.input_path)
        if args.max_samples > 0:
            eval_items = eval_items[: args.max_samples]
        if not eval_items:
            raise RuntimeError("No valid rows found for evaluation.")

        results: List[Dict[str, Any]] = []
        for item in tqdm(eval_items, desc="Classifying"):
            pred = classifier(generated_text=item["generated_text"])
            pred_label = normalize_pred_label(pred)
            # Score by exact match on one recognised label, not a substring test.
            matched_label = _match_label(pred_label)
            results.append(
                {
                    "line_no": item["line_no"],
                    "model": item["model"],
                    "row_index": item["row_index"],
                    "doc_id": item["doc_id"],
                    "gold_label": item["gold_label"],
                    "pred_label": pred_label,
                    "matched_label": matched_label,
                    "is_correct": matched_label == item["gold_label"],
                }
            )

        total = len(results)
        correct = sum(1 for r in results if r["is_correct"])
        overall_accuracy = correct / total if total else 0.0
        per_model = _per_model_stats(results)

        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        os.makedirs(args.output_dir, exist_ok=True)
        summary_path = os.path.join(args.output_dir, f"classifier_eval_gpt5_{ts}.json")
        details_path = os.path.join(args.output_dir, f"classifier_eval_gpt5_{ts}.jsonl")

        summary_obj = {
            "model_path": args.model_path,
            "input_path": args.input_path,
            "api_base": args.api_base,
            "total_samples": total,
            "correct_samples": correct,
            "accuracy_score": overall_accuracy,
            "per_model": per_model,
            "details_path": details_path,
        }

        with open(summary_path, "w", encoding="utf-8") as f:
            json.dump(summary_obj, f, indent=2, ensure_ascii=False)

        with open(details_path, "w", encoding="utf-8") as f:
            for record in results:
                f.write(json.dumps(record, ensure_ascii=False) + "\n")

        print(json.dumps(summary_obj, indent=2, ensure_ascii=False))
        print(f"[DONE] Summary saved: {summary_path}")
        print(f"[DONE] Details saved: {details_path}")

    except Exception as exc:
        print(f"[error] {type(exc).__name__}: {exc}")
        if args.provide_traceback:
            traceback.print_exc()
        raise
| |
|
| |
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
| |
|