| import argparse |
| import json |
| import os |
| import re |
| import traceback |
| import urllib.error |
| import urllib.request |
| from datetime import datetime |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| import dspy |
| import requests |
| from tqdm import tqdm |
|
|
|
|
# Default base URLs for the two HTTP services used by this script: the
# OpenAI-compatible classifier endpoint and the /check_support service.
DEFAULT_CLASSIFIER_API_BASE = "http://172.16.34.19:8040/v1"
DEFAULT_SUPPORT_API_BASE = "http://172.16.34.19:8090"

# All default filesystem locations live under the same checkout root.
_CODE_ROOT = "/home/mshahidul/readctrl/code"
_INFERENCE_DIR = _CODE_ROOT + "/readctrl_rl_inference"

# Compiled classifier state, inference JSONL to evaluate, reference subclaims,
# and the directory that receives the evaluation summary/details.
DEFAULT_MODEL_PATH = _INFERENCE_DIR + "/model.json"
DEFAULT_INPUT_FILE = (
    _INFERENCE_DIR
    + "/vllm_model_result/"
    + "vllm_inference_qwen-qwen3-4b-instruct-2507_20260302_065314.jsonl"
)
DEFAULT_REFERENCE_SUBCLAIMS_FILE = (
    _CODE_ROOT
    + "/text_classifier/data/"
    + "verified_combined_0-80_clean200_with_subclaims.json"
)
DEFAULT_OUTPUT_DIR = _INFERENCE_DIR + "/test_result_v5"


# The only gold/reference labels that are evaluated; rows carrying any other
# label are skipped by the loaders and by the evaluation loop.
VALID_LABELS = {
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
}
|
|
| |
# Sentence fragments shorter than this many characters are discarded by
# _split_into_sentences unless the caller passes a different min_chars.
MIN_SENTENCE_CHARS = 15


def _split_into_sentences(text: str, min_chars: int = MIN_SENTENCE_CHARS) -> List[str]:
    """Break ``text`` into sentences, keeping only those of at least ``min_chars`` characters.

    A boundary is any run of whitespace that directly follows ``.``, ``!`` or
    ``?``. Each piece is stripped before its length is checked. Empty, ``None``
    or whitespace-only input yields an empty list.
    """
    stripped = text.strip() if text else ""
    if not stripped:
        return []

    sentences: List[str] = []
    for piece in re.split(r"(?<=[.!?])\s+", stripped):
        candidate = piece.strip()
        if len(candidate) >= min_chars:
            sentences.append(candidate)
    return sentences
|
|
|
|
| |
| |
| |
|
|
# Field descriptions for HealthLiteracySignature, kept at module level so the
# class body stays a plain list of fields.
_GENERATED_TEXT_DESC = "A version of the source text rewritten for a specific audience."
_LITERACY_LABEL_DESC = (
    "Classification: low_health_literacy (simple words, no jargon), "
    "intermediate_health_literacy (moderate technicality), or "
    "proficient_health_literacy (highly technical/original level)."
)


class HealthLiteracySignature(dspy.Signature):
    # Input: the rewritten text to classify. Output: one of the three
    # health-literacy labels listed in the output field description.
    # NOTE(review): no class docstring is added on purpose; DSPy may use a
    # Signature's docstring as prompt instructions, so adding one could change
    # model behavior. Confirm before documenting this class with a docstring.
    generated_text = dspy.InputField(desc=_GENERATED_TEXT_DESC)
    literacy_label = dspy.OutputField(desc=_LITERACY_LABEL_DESC)
|
|
|
|
class HealthLiteracyClassifier(dspy.Module):
    """DSPy module that classifies text via chain-of-thought over HealthLiteracySignature."""

    def __init__(self):
        super().__init__()
        # NOTE(review): the attribute name `classifier` is kept unchanged;
        # presumably saved/compiled state refers to it by name. Verify before renaming.
        self.classifier = dspy.ChainOfThought(HealthLiteracySignature)

    def forward(self, generated_text):
        """Run the chain-of-thought predictor on ``generated_text`` and return its prediction."""
        prediction = self.classifier(generated_text=generated_text)
        return prediction
|
|
|
|
| |
| |
| |
|
|
class MedicalClaimVerifier:
    """
    Client for the FastAPI ``/check_support`` endpoint (the original author
    notes this is the same approach as reward_new_v5.py).

    ``base_url`` is the server root, e.g. ``http://host:8090`` (NO ``/v1`` suffix).

    Computes:
        completeness  — fraction of summary_subclaims covered by gen_text (recall)
        hallucination — gen_text sentences NOT supported by input_text, divided
                        by max(#gen sentences, #input sentences)
    """

    def __init__(self, base_url: str, timeout: float = 300):
        """
        Args:
            base_url: Server root; a trailing ``/`` is removed.
            timeout: Per-request timeout in seconds passed to ``requests.post``.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout

    @staticmethod
    def _coerce_labels(result: Any, expected_len: int) -> List[Any]:
        """Return exactly ``expected_len`` labels extracted from a decoded response body.

        A body that is not a dict, or whose ``labels`` field is missing or not
        a list/tuple, yields all ``'invalid'``. Short lists are padded with
        ``'invalid'``; long lists are truncated.
        """
        raw = result.get("labels") if isinstance(result, dict) else None
        if not isinstance(raw, (list, tuple)):
            return ["invalid"] * expected_len
        # Copy so the decoded response object is never mutated.
        labels = list(raw[:expected_len])
        labels.extend(["invalid"] * (expected_len - len(labels)))
        return labels

    @staticmethod
    def _usable_labels(labels: List[Any]) -> List[str]:
        """Normalise labels (str, strip, lower) and drop the ``'invalid'`` ones."""
        normalized = (str(lbl).strip().lower() for lbl in labels)
        return [lbl for lbl in normalized if lbl != "invalid"]

    def _call_support_api(
        self,
        context: str,
        subclaims: List[str],
        threshold: float = 0.5,
        batch_size: int = 128,
    ) -> Optional[List[str]]:
        """
        POST {base_url}/check_support.

        Returns one label per subclaim ('supported' | 'not_supported' |
        'invalid' per the original author's description of the service), or
        None when the request fails or the body is not valid JSON, so the
        caller can skip the component. Empty context or subclaims short-circuit
        to all 'invalid' without any network call.
        """
        if not context or not subclaims:
            return ["invalid"] * len(subclaims or [])

        payload = {
            "context": context,
            "subclaims": subclaims,
            "threshold": threshold,
            "batch_size": batch_size,
        }
        try:
            response = requests.post(
                f"{self.base_url}/check_support", json=payload, timeout=self.timeout
            )
            response.raise_for_status()
            result = response.json()
        except (requests.exceptions.RequestException, ValueError) as exc:
            # ValueError covers JSON decode failures on requests versions whose
            # decode error is not a RequestException subclass.
            print(f"Warning: Support API call failed (returning None): {exc}")
            return None
        return self._coerce_labels(result, len(subclaims))

    def compute_completeness(
        self,
        summary_subclaims: List[str],
        gen_text: str,
        threshold: float = 0.5,
        batch_size: int = 128,
    ) -> Optional[float]:
        """
        Completeness ∈ [0, 1]: fraction of summary_subclaims covered by gen_text.

        Recall direction: subclaims = summary subclaims, context = gen_text.
        'invalid' labels are excluded from both numerator and denominator.
        Returns 0.0 when there are no subclaims or gen_text is blank, and None
        when the API call fails or every label is 'invalid'.
        """
        if not summary_subclaims:
            return 0.0
        if not gen_text or not gen_text.strip():
            return 0.0

        labels = self._call_support_api(
            context=gen_text,
            subclaims=summary_subclaims,
            threshold=threshold,
            batch_size=batch_size,
        )
        if labels is None:
            print("Warning: completeness API failure — skipping component.")
            return None

        usable = self._usable_labels(labels)
        if not usable:
            print("Warning: all completeness labels were 'invalid' — skipping.")
            return None

        covered = sum(1 for lbl in usable if lbl == "supported")
        return covered / len(usable)

    def compute_hallucination(
        self,
        input_text: str,
        gen_text: str,
        threshold: float = 0.5,
        batch_size: int = 128,
    ) -> Optional[float]:
        """
        Hallucination ∈ [0, 1]: gen_text sentences NOT supported by input_text,
        divided by a stable denominator = max(n_gen, n_input) sentences (the
        original author notes this mirrors reward_new_v5.py).

        'invalid' labels are not counted as hallucinated, but the denominator
        is not reduced for them. Returns 0.0 when gen_text has no usable
        sentences or input_text is blank, and None when the API call fails or
        every label is 'invalid'.
        """
        gen_segments = _split_into_sentences(gen_text)
        if not gen_segments or not input_text or not input_text.strip():
            return 0.0

        input_sentences = _split_into_sentences(input_text)
        stable_denom = max(len(gen_segments), len(input_sentences))
        if stable_denom == 0:
            return 0.0

        labels = self._call_support_api(
            context=input_text,
            subclaims=gen_segments,
            threshold=threshold,
            batch_size=batch_size,
        )
        if labels is None:
            print("Warning: hallucination API failure — skipping component.")
            return None

        usable = self._usable_labels(labels)
        if not usable:
            print("Warning: all hallucination labels were 'invalid' — skipping.")
            return None

        hallucinated = sum(1 for lbl in usable if lbl != "supported")
        return hallucinated / stable_denom

    def evaluate_sample(
        self,
        gen_text: str,
        summary_subclaims: List[str],
        input_text: str,
        threshold: float = 0.5,
        batch_size: int = 128,
    ) -> Tuple[Optional[float], Optional[float]]:
        """
        Returns (completeness_score, hallucination_score).

        Either can be None if the API failed for that component. ``threshold``
        and ``batch_size`` are forwarded to both component computations.
        """
        completeness = self.compute_completeness(
            summary_subclaims=summary_subclaims,
            gen_text=gen_text,
            threshold=threshold,
            batch_size=batch_size,
        )
        hallucination = self.compute_hallucination(
            input_text=input_text,
            gen_text=gen_text,
            threshold=threshold,
            batch_size=batch_size,
        )
        return completeness, hallucination
|
|
|
|
| |
| |
| |
|
|
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Build the CLI parser and parse ``argv``.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what a bare ``parse_args()`` call
            has always done; passing a list lets the parser be driven
            programmatically.

    Returns:
        The parsed namespace. Endpoint defaults come from the ``VLLM_API_BASE``
        and ``SUPPORT_API_BASE`` environment variables when set, otherwise from
        the module-level DEFAULT_* constants.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Evaluate classifier accuracy + completeness (recall) + "
            "hallucination score — mirrors reward_new_v5.py."
        )
    )

    # Input/output locations.
    parser.add_argument("--model-path", default=DEFAULT_MODEL_PATH)
    parser.add_argument(
        "--input-file",
        default=DEFAULT_INPUT_FILE,
        help="Path to RL inference JSONL.",
    )
    parser.add_argument(
        "--reference-subclaims-file",
        default=DEFAULT_REFERENCE_SUBCLAIMS_FILE,
        help=(
            "JSON list with summary_subclaims + input_text keyed by (doc_id, label)."
        ),
    )

    # Service endpoints; environment variables take precedence over the
    # built-in defaults, and explicit flags take precedence over both.
    parser.add_argument(
        "--classifier-api-base",
        default=os.environ.get("VLLM_API_BASE", DEFAULT_CLASSIFIER_API_BASE),
    )
    parser.add_argument(
        "--support-api-base",
        default=os.environ.get("SUPPORT_API_BASE", DEFAULT_SUPPORT_API_BASE),
        help="FastAPI /check_support base URL (NO /v1 suffix).",
    )
    parser.add_argument("--output-dir", default=DEFAULT_OUTPUT_DIR)
    parser.add_argument(
        "--generated-text-key",
        default="generated_text",
        help="Field name for generated text in input JSONL.",
    )

    # Pass/fail thresholds applied to the per-sample scores.
    parser.add_argument(
        "--comp-threshold",
        type=float,
        default=0.5,
        help="Completeness pass threshold (score >= this value counts as pass).",
    )
    parser.add_argument(
        "--hallucination-threshold",
        type=float,
        default=0.1,
        help="Hallucination fail threshold (score > this value counts as fail).",
    )

    # Run controls.
    parser.add_argument(
        "--max-samples",
        type=int,
        default=-1,
        help="Use -1 for all rows.",
    )
    parser.add_argument(
        "--provide-traceback",
        action="store_true",
        help="Print full traceback on runtime error.",
    )
    return parser.parse_args(argv)
|
|
|
|
| |
| |
| |
|
|
def check_api_base(api_base: str) -> None:
    """Health-check for the OpenAI-compatible /models endpoint (classifier).

    Sends ``GET {api_base}/models`` with a 5-second timeout.

    Raises:
        RuntimeError: The server answered with an HTTP status >= 400.
        ConnectionError: The server could not be reached (DNS/connect failure,
            timeout, reset) or did not speak valid HTTP.
    """
    import http.client  # local import: only needed for the protocol-error case

    models_url = api_base.rstrip("/") + "/models"
    req = urllib.request.Request(models_url, method="GET")
    try:
        with urllib.request.urlopen(req, timeout=5) as resp:
            status = resp.status
    except urllib.error.HTTPError as exc:
        # urlopen raises HTTPError for 4xx/5xx responses. HTTPError is a
        # URLError subclass, so it must be handled first: the server DID
        # answer, it is just unhealthy.
        raise RuntimeError(
            f"Endpoint reachable but unhealthy: {models_url} (status={exc.code})"
        ) from exc
    except (OSError, http.client.HTTPException) as exc:
        # URLError, socket timeouts and connection resets are all OSError
        # subclasses; HTTPException covers malformed (non-HTTP) replies.
        raise ConnectionError(
            "Cannot reach OpenAI-compatible endpoint. "
            f"api_base={api_base}. "
            "Start your vLLM server or pass correct api base."
        ) from exc

    # Defensive: kept in case a custom opener returns an error status
    # instead of raising.
    if status >= 400:
        raise RuntimeError(
            f"Endpoint reachable but unhealthy: {models_url} (status={status})"
        )
|
|
|
|
def check_support_api_base(api_base: str) -> None:
    """Health-check for the FastAPI /check_support endpoint.

    Posts a minimal probe payload with a 5-second timeout. Only 5xx responses
    are treated as failures; any other status counts as reachable.

    Raises:
        ConnectionError: The request could not connect or timed out.
        RuntimeError: The server replied with a status code >= 500.
    """
    url = api_base.rstrip("/") + "/check_support"
    probe = {"context": "test", "subclaims": ["test"], "threshold": 0.5, "batch_size": 1}

    try:
        resp = requests.post(url, json=probe, timeout=5)
    except requests.exceptions.ConnectionError as exc:
        raise ConnectionError(
            f"Cannot reach Support API: {url}. Ensure the FastAPI server is running."
        ) from exc
    except requests.exceptions.Timeout as exc:
        raise ConnectionError(f"Support API timed out: {url}") from exc

    status = resp.status_code
    if status >= 500:
        raise RuntimeError(f"Support API server error: {url} (status={status})")
|
|
|
|
| |
| |
| |
|
|
def load_compiled_classifier(path: str):
    """Load the compiled classifier from ``path``.

    Tries ``dspy.load(path)`` first when this DSPy version provides it. If
    that fails for any reason, falls back to building a fresh
    ``HealthLiteracyClassifier`` and calling its ``load(path)``.

    The first failure is deliberately non-fatal (best-effort), but it is no
    longer discarded: if the fallback fails too, its details are included in
    the raised error so both causes are visible.

    Raises:
        RuntimeError: Neither loading strategy succeeded.
    """
    first_error: Optional[Exception] = None
    if hasattr(dspy, "load"):
        try:
            return dspy.load(path)
        except Exception as exc:
            # Best-effort attempt; remember why it failed and try the fallback.
            first_error = exc

    classifier = HealthLiteracyClassifier()
    try:
        classifier.load(path)
    except Exception as exc:
        detail = ""
        if first_error is not None:
            detail = (
                f" (dspy.load also failed: "
                f"{type(first_error).__name__}: {first_error})"
            )
        raise RuntimeError(
            f"Failed to load compiled model from {path}{detail}"
        ) from exc
    return classifier
|
|
|
|
def normalize_pred_label(pred_obj: Any) -> str:
    """Return the prediction's ``literacy_label`` as a stripped, lower-cased string.

    Returns ``""`` when ``pred_obj`` is falsy or has no ``literacy_label``
    attribute.
    """
    if pred_obj and hasattr(pred_obj, "literacy_label"):
        raw_label = pred_obj.literacy_label
        return str(raw_label).strip().lower()
    return ""
|
|
|
|
def load_items(path: str, generated_text_key: str) -> List[Dict[str, Any]]:
    """Read the inference JSONL at ``path`` into a list of normalised row dicts.

    Blank lines are skipped. Each returned dict has the keys ``line_no``
    (1-based line in the file), ``row_index``, ``doc_id``, ``gold_label``,
    ``generated_text`` and ``input_text``. The generated text is read from
    ``generated_text_key`` and falls back to ``"generated_text"`` when that
    key is missing or null.

    JSON ``null`` values are mapped to ``""`` rather than the literal string
    ``"None"``, so null fields are treated as missing by downstream checks.

    Raises:
        json.JSONDecodeError: A line is not valid JSON; the message names the
            file and line number.
        ValueError: A line decodes to something other than a JSON object.
    """

    def clean(value: Any) -> str:
        # str(None) would be "None", which downstream code treats as real text.
        return "" if value is None else str(value).strip()

    items: List[Dict[str, Any]] = []
    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                row = json.loads(line)
            except json.JSONDecodeError as exc:
                # Same exception type, with file/line context added.
                raise json.JSONDecodeError(
                    f"{exc.msg} (in {path}, line {line_no})", exc.doc, exc.pos
                ) from exc
            if not isinstance(row, dict):
                raise ValueError(
                    f"Expected a JSON object on line {line_no} of {path}, "
                    f"got {type(row).__name__}."
                )

            generated_text = row.get(generated_text_key)
            if generated_text is None:
                generated_text = row.get("generated_text")

            items.append(
                {
                    "line_no": line_no,
                    "row_index": row.get("row_index"),
                    "doc_id": row.get("doc_id"),
                    "gold_label": clean(row.get("gold_label")),
                    "generated_text": clean(generated_text),
                    "input_text": clean(row.get("input_text")),
                }
            )
    return items
|
|
|
|
def load_reference_lookup(
    reference_path: str,
) -> Dict[Tuple[Any, str], Dict[str, Any]]:
    """
    Build a lookup keyed by (doc_id, label) → dict with:
        summary_subclaims : List[str] — used for completeness
        input_text        : str       — used for hallucination

    Each entry is registered under both the raw ``doc_id`` and ``str(doc_id)``
    so callers can look it up with either form; the first row seen for a key
    wins. Rows are skipped when they are not JSON objects, their label is not
    in VALID_LABELS, or they have no non-empty subclaim list.

    ``summary_subclaims`` falls back to ``gold_subclaims`` and ``input_text``
    falls back to ``fulltext`` when the primary key is missing OR null. A null
    text becomes ``""`` rather than the literal string ``"None"``, which would
    otherwise be used as a real support context.

    Raises:
        ValueError: The file is not a JSON list, or no usable row was found.
    """
    with open(reference_path, "r", encoding="utf-8") as f:
        rows = json.load(f)
    if not isinstance(rows, list):
        raise ValueError("Reference file must be a JSON list.")

    def first_present(row: Dict[str, Any], *keys: str) -> Any:
        # First value that exists and is not JSON null, else None.
        for key in keys:
            value = row.get(key)
            if value is not None:
                return value
        return None

    lookup: Dict[Tuple[Any, str], Dict[str, Any]] = {}
    valid_label_rows = 0
    rows_with_keys = 0

    for row in rows:
        if not isinstance(row, dict):
            continue

        raw_label = row.get("label")
        label = "" if raw_label is None else str(raw_label).strip()
        if label not in VALID_LABELS:
            continue
        valid_label_rows += 1

        summary_subclaims = first_present(row, "summary_subclaims", "gold_subclaims")
        if not isinstance(summary_subclaims, list) or not summary_subclaims:
            continue
        rows_with_keys += 1

        raw_text = first_present(row, "input_text", "fulltext")
        input_text = "" if raw_text is None else str(raw_text).strip()

        entry = {"summary_subclaims": summary_subclaims, "input_text": input_text}
        doc_id = row.get("doc_id")
        for key in ((doc_id, label), (str(doc_id), label)):
            lookup.setdefault(key, entry)

    if not lookup:
        raise ValueError(
            "Reference lookup is empty. Expected JSON rows with "
            "`summary_subclaims` list fields keyed by (doc_id, label). "
            f"valid_label_rows={valid_label_rows}, "
            f"rows_with_keys={rows_with_keys}, "
            f"reference_path={reference_path}"
        )
    return lookup
|
|
|
|
| |
| |
| |
|
|
def main() -> None:
    """Run the end-to-end evaluation.

    Steps:
      1. Parse CLI arguments and verify the model, input and reference files exist.
      2. Health-check both services, configure DSPy, load the compiled
         classifier, the reference lookup and the input rows.
      3. For every row with a valid gold label, non-empty generated text and a
         matching reference entry: classify the text and score completeness and
         hallucination.
      4. Write a summary JSON and a per-sample JSONL to the output directory,
         checkpointing every CHECKPOINT_EVERY evaluated samples.

    If the loop is interrupted (any exception, including Ctrl-C), the samples
    evaluated so far are saved before the exception propagates, so work done
    since the last periodic checkpoint is not lost.

    Raises:
        FileNotFoundError: A required input file is missing.
        RuntimeError: No row qualified for evaluation.
        Exception: Any other runtime error is reported and re-raised.
    """
    args = parse_args()
    if not os.path.exists(args.model_path):
        raise FileNotFoundError(f"Model file not found: {args.model_path}")
    if not os.path.exists(args.input_file):
        raise FileNotFoundError(f"Input file not found: {args.input_file}")
    if not os.path.exists(args.reference_subclaims_file):
        raise FileNotFoundError(
            f"Reference file not found: {args.reference_subclaims_file}"
        )

    try:
        check_api_base(args.classifier_api_base)
        check_support_api_base(args.support_api_base)

        lm = dspy.LM(
            model="openai/dspy",
            api_base=args.classifier_api_base,
            api_key="EMPTY",
            temperature=0.0,
        )
        dspy.configure(lm=lm)
        classifier = load_compiled_classifier(args.model_path)
        verifier = MedicalClaimVerifier(base_url=args.support_api_base)
        reference_lookup = load_reference_lookup(args.reference_subclaims_file)

        rows = load_items(args.input_file, args.generated_text_key)
        if args.max_samples > 0:
            rows = rows[: args.max_samples]

        # Per-run counters. `total` counts evaluated samples only.
        unmatched_rows = 0
        total = 0
        classifier_correct = 0
        comp_pass_count = 0
        halluc_fail_count = 0
        cls_and_comp_pass_count = 0
        cls_comp_no_halluc_count = 0

        # Running sums for the means; the *_n counters exclude samples whose
        # score is None (API failure), so the means cover scored samples only.
        comp_sum = 0.0
        comp_n = 0
        halluc_sum = 0.0
        halluc_n = 0

        details: List[Dict[str, Any]] = []

        CHECKPOINT_EVERY = 10

        os.makedirs(args.output_dir, exist_ok=True)
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        summary_path = os.path.join(
            args.output_dir, f"classifier_subclaim_threshold_eval_{ts}.json"
        )
        details_path = os.path.join(
            args.output_dir, f"classifier_subclaim_threshold_eval_{ts}.jsonl"
        )

        def build_summary() -> Dict[str, Any]:
            """Snapshot of the run configuration and aggregate metrics so far."""

            def safe_rate(n: int) -> float:
                return n / total if total else 0.0

            return {
                "model_path": args.model_path,
                "input_file": args.input_file,
                "reference_subclaims_file": args.reference_subclaims_file,
                "generated_text_key": args.generated_text_key,
                "classifier_api_base": args.classifier_api_base,
                "support_api_base": args.support_api_base,
                "total_samples": total,
                "unmatched_rows": unmatched_rows,
                "classifier_only_accuracy": safe_rate(classifier_correct),
                "completeness_pass_rate": safe_rate(comp_pass_count),
                "completeness_mean": comp_sum / comp_n if comp_n else None,
                "completeness_threshold": args.comp_threshold,
                "hallucination_fail_rate": safe_rate(halluc_fail_count),
                "hallucination_mean": halluc_sum / halluc_n if halluc_n else None,
                "hallucination_threshold": args.hallucination_threshold,
                "accuracy_cls_and_completeness": safe_rate(cls_and_comp_pass_count),
                "accuracy_cls_comp_no_hallucination": safe_rate(cls_comp_no_halluc_count),
                "details_path": details_path,
            }

        def save_checkpoint() -> None:
            """Rewrite the summary JSON and the full details JSONL from memory."""
            with open(summary_path, "w", encoding="utf-8") as f_sum:
                json.dump(build_summary(), f_sum, indent=2)
            with open(details_path, "w", encoding="utf-8") as f_det:
                for item in details:
                    f_det.write(json.dumps(item, ensure_ascii=False) + "\n")

        try:
            for idx, row in enumerate(tqdm(rows, desc="Evaluating"), start=1):
                gold_label = str(row.get("gold_label", "")).strip()
                if gold_label not in VALID_LABELS:
                    continue

                generated_text = str(row.get("generated_text", "")).strip()
                doc_id = row.get("doc_id")

                ref = reference_lookup.get((doc_id, gold_label)) or reference_lookup.get(
                    (str(doc_id), gold_label)
                )
                if not generated_text or not ref:
                    # Only reference misses are counted; rows with empty
                    # generated text are skipped without being counted.
                    if not ref:
                        unmatched_rows += 1
                    continue

                summary_subclaims = ref["summary_subclaims"]
                # Prefer the reference file's source text; fall back to the row's.
                input_text = ref.get("input_text") or row.get("input_text", "")

                total += 1

                # Classification: a prediction is correct when the gold label
                # appears anywhere in the normalised predicted label.
                pred = classifier(generated_text=generated_text)
                pred_label = normalize_pred_label(pred)
                is_cls_correct = gold_label in pred_label
                classifier_correct += int(is_cls_correct)

                comp_score, halluc_score = verifier.evaluate_sample(
                    gen_text=generated_text,
                    summary_subclaims=summary_subclaims,
                    input_text=input_text,
                )

                # A None score (API failure) never counts as a completeness pass.
                comp_pass = (comp_score is not None) and (comp_score >= args.comp_threshold)
                comp_pass_count += int(comp_pass)
                if comp_score is not None:
                    comp_sum += comp_score
                    comp_n += 1

                # A None score (API failure) never counts as a hallucination fail.
                halluc_fail = (halluc_score is not None) and (
                    halluc_score > args.hallucination_threshold
                )
                halluc_fail_count += int(halluc_fail)
                if halluc_score is not None:
                    halluc_sum += halluc_score
                    halluc_n += 1

                cls_and_comp = is_cls_correct and comp_pass
                cls_comp_no_halluc = cls_and_comp and not halluc_fail
                cls_and_comp_pass_count += int(cls_and_comp)
                cls_comp_no_halluc_count += int(cls_comp_no_halluc)

                details.append(
                    {
                        "idx": idx,
                        "line_no": row.get("line_no"),
                        "row_index": row.get("row_index"),
                        "doc_id": doc_id,
                        "gold_label": gold_label,
                        "generated_text": generated_text,
                        "pred_label": pred_label,
                        "classifier_correct": is_cls_correct,
                        "completeness_score": comp_score,
                        "completeness_pass": comp_pass,
                        "completeness_threshold": args.comp_threshold,
                        "hallucination_score": halluc_score,
                        "hallucination_fail": halluc_fail,
                        "hallucination_threshold": args.hallucination_threshold,
                        "pass_cls_and_completeness": cls_and_comp,
                        "pass_cls_comp_no_hallucination": cls_comp_no_halluc,
                    }
                )

                if total % CHECKPOINT_EVERY == 0:
                    save_checkpoint()
                    comp_avg = f"{comp_sum/comp_n:.4f}" if comp_n else "N/A"
                    halluc_avg = f"{halluc_sum/halluc_n:.4f}" if halluc_n else "N/A"
                    print(
                        f"\n[CHECKPOINT] {total} samples — "
                        f"cls_acc={classifier_correct/total:.4f}, "
                        f"comp_pass={comp_pass_count/total:.4f} (mean={comp_avg}), "
                        f"halluc_fail={halluc_fail_count/total:.4f} (mean={halluc_avg})"
                    )
        except BaseException:
            # Interrupted mid-run (error or Ctrl-C): persist what has been
            # evaluated so far, then let the original exception propagate.
            if details:
                try:
                    save_checkpoint()
                    print(
                        f"\n[PARTIAL] Saved {len(details)} evaluated samples "
                        f"before exit: {details_path}"
                    )
                except OSError as save_exc:
                    print(f"Warning: could not save partial results: {save_exc}")
            raise

        if total == 0:
            raise RuntimeError("No valid rows were found for evaluation.")

        save_checkpoint()

        summary = build_summary()
        print(json.dumps(summary, indent=2))
        print(f"[DONE] Summary saved: {summary_path}")
        print(f"[DONE] Details saved: {details_path}")

    except Exception as exc:
        print(f"[error] {type(exc).__name__}: {exc}")
        if args.provide_traceback:
            traceback.print_exc()
        raise


if __name__ == "__main__":
    main()
|
|