import argparse
import json
import os
import re
import time
import warnings
from typing import Any, Dict, List, Optional

import requests
# NOTE(review): this silences *every* warning process-wide, including
# deprecation notices from dependencies. Narrow the filter if diagnostics
# are ever needed.
warnings.filterwarnings("ignore")

# Module-level test flag. A preceding `test_mode = True` assignment was a dead
# store (it was overwritten immediately), so only the effective value is kept.
test_mode = False

# dspy is an optional dependency: when it is missing the name is bound to None
# and the literacy-classifier code paths degrade instead of failing on import.
try:
    import dspy
except ImportError:
    dspy = None

# Base URL of the support-checking service; override with the
# SUPPORT_API_BASE environment variable.
SUPPORT_API_BASE = os.getenv("SUPPORT_API_BASE", "http://172.16.34.19:8090")
|
|
|
|
| |
| |
| |
|
|
def _call_support_api(
    context: str,
    subclaims: List[str],
    threshold: float = 0.5,
    batch_size: int = 128,
    max_retries: int = 3,
    initial_retry_delay: float = 5.0,
    backoff_factor: float = 2.0,
) -> Optional[List[str]]:
    """
    Call the FastAPI /check_support endpoint, retrying with exponential backoff.

    Parameters
    ----------
    context : premise text the subclaims are checked against.
    subclaims : statements to verify; ``None`` is treated as an empty list.
    threshold, batch_size : forwarded unchanged in the request payload.
    max_retries : number of retries after the first failed attempt.
    initial_retry_delay : seconds to wait before the first retry.
    backoff_factor : multiplier applied to the delay after each failure.

    Returns
    -------
    List[str] : one label per subclaim — "supported" | "not_supported" | "invalid".
                All-"invalid" is returned without any network call when
                `context` or `subclaims` is empty, and when the response body
                is not a JSON object carrying a list under "labels".
    None      : returned on a TOTAL network/transport failure, so callers can
                distinguish a genuine API error from a valid "not_supported"
                label and avoid applying a false penalty.
    """
    if subclaims is None:
        subclaims = []
    if not context or not subclaims:
        return ["invalid"] * len(subclaims)

    api_url = f"{SUPPORT_API_BASE}/check_support"
    payload = {
        "context": context,
        "subclaims": subclaims,
        "threshold": threshold,
        "batch_size": batch_size,
    }
    all_invalid = ["invalid"] * len(subclaims)

    attempt = 0
    while True:
        try:
            response = requests.post(api_url, json=payload, timeout=300)
            response.raise_for_status()
            result = response.json()
        except (requests.exceptions.RequestException, ValueError) as exc:
            # ValueError covers JSON decode errors on requests versions where
            # they are not a RequestException subclass.
            attempt += 1
            if attempt > max_retries:
                print(
                    f"Warning: Support API call failed after {max_retries} retries "
                    f"(returning None): {exc}"
                )
                return None

            delay = initial_retry_delay * (backoff_factor ** (attempt - 1))
            print(
                f"Warning: Support API call failed (attempt {attempt}/{max_retries}); "
                f"retrying in {delay:.1f}s: {exc}"
            )
            # Clamp so a misconfigured negative delay cannot make sleep() raise.
            time.sleep(max(0.0, delay))
            continue

        # A well-formed HTTP reply with an unexpected body is not a transport
        # failure: report every subclaim as "invalid" instead of crashing.
        if not isinstance(result, dict):
            return all_invalid
        labels = result.get("labels")
        if not isinstance(labels, list):
            return all_invalid
        return labels
|
|
|
|
| |
| |
| |
|
|
| |
| |
| |
# Segments shorter than this many characters are not counted as sentences.
MIN_SENTENCE_CHARS = 15


def _split_into_sentences(text: str, min_chars: int = MIN_SENTENCE_CHARS) -> List[str]:
    """
    Break `text` into sentences on whitespace that follows '.', '!' or '?'.

    Each piece is stripped, and pieces shorter than `min_chars` characters are
    discarded so that tiny fragments cannot be used to pad ratio-based scores.
    Empty, whitespace-only or falsy input yields an empty list.
    """
    stripped = text.strip() if text else ""
    if not stripped:
        return []

    sentences = []
    for fragment in re.split(r"(?<=[.!?])\s+", stripped):
        candidate = fragment.strip()
        if len(candidate) >= min_chars:
            sentences.append(candidate)
    return sentences
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def compute_incompleteness_score(
    summary_text: str,
    generated_text: str,
    threshold: float = 0.5,
    batch_size: int = 32,
) -> Optional[float]:
    """
    Incompleteness score in [0, 1]: fraction of summary_text sentences
    NOT covered by generated_text. Returns None on API failure.

    Direction: summary_text sentences are the 'subclaims'; generated_text
    is the 'context' (premise). This is the recall direction.

    Short-circuits (no API call)
    ----------------------------
    - summary_text yields no sentences after filtering -> 0.0
    - generated_text is empty / whitespace-only        -> 1.0

    API-failure handling
    --------------------
    - Total failure (_call_support_api returns None, or something that is
      not a list/tuple of labels) -> return None.
      The caller treats None as a null signal (no completeness component),
      preventing a spurious zero-completeness penalty from destabilising RL.
    - Partial failure (some labels are "invalid") -> those labels are filtered
      out; only genuinely adjudicated labels contribute to the score.
      If ALL labels are invalid, returns None (treated as total failure).
    """
    summary_sentences = _split_into_sentences(summary_text)
    if not summary_sentences:
        return 0.0
    if not generated_text or not generated_text.strip():
        return 1.0

    labels = _call_support_api(
        context=generated_text,
        subclaims=summary_sentences,
        threshold=threshold,
        batch_size=batch_size,
    )

    # Anything other than a sequence of labels is treated as a total failure.
    if labels is None or not isinstance(labels, (list, tuple)):
        print("Warning: compute_incompleteness_score received None from API — returning None.")
        return None

    # Normalise once so casing/whitespace variations compare equal.
    normalized = [str(lbl).strip().lower() for lbl in labels]
    valid_labels = [lbl for lbl in normalized if lbl != "invalid"]
    if not valid_labels:
        print("Warning: all labels were 'invalid' in compute_incompleteness_score — returning None.")
        return None

    # Every adjudicated label other than "supported" counts as not covered.
    not_covered = sum(1 for lbl in valid_labels if lbl != "supported")
    return not_covered / len(valid_labels)
|
|
|
|
def compute_completeness_reward(
    summary_text: str,
    generated_text: str,
    threshold: float = 0.5,
    batch_size: int = 128,
) -> Optional[float]:
    """
    Completeness reward in [0, 1]: fraction of summary_text sentences
    that ARE covered by generated_text (i.e. 1 - incompleteness_score).
    Returns None if the API failed (propagated from compute_incompleteness_score).

    This is the RECALL direction:
        completeness_reward = covered_summary_sentences / adjudicated_summary_sentences

    Because the denominator comes from the summary rather than from the
    generated text, producing a shorter output does not shrink it. The score
    only rises when more summary sentences are judged "supported".
    """
    incompleteness_score = compute_incompleteness_score(
        summary_text=summary_text,
        generated_text=generated_text,
        threshold=threshold,
        batch_size=batch_size,
    )
    # Propagate the "no signal" marker untouched so the caller can drop the term.
    if incompleteness_score is None:
        return None
    return 1.0 - incompleteness_score
|
|
|
|
| |
| |
| |
|
|
def compute_hallucination_score_vs_input(
    input_text: str,
    generated_text: str,
    threshold: float = 0.5,
    batch_size: int = 128,
) -> Optional[float]:
    """
    Hallucination score in [0, 1]: fraction of generated sentences
    NOT supported by input_text. Returns None on API failure.

    Returns 0.0 without calling the API when generated_text yields no
    sentences after filtering or when input_text is empty.

    Anti-padding design
    -------------------
    1. Minimum-length filter: segments < MIN_SENTENCE_CHARS chars are discarded.
    2. Fixed denominator: max(n_gen_filtered, n_input_sentences) so padding
       safe sentences cannot dilute the hallucination ratio.

    API-failure handling
    --------------------
    - Total failure (None, or a non-list payload, from the API) -> return None.
      The caller omits the hallucination penalty rather than applying a
      massive spurious penalty from a transient server blip.
    - Partial failure (some "invalid" labels) -> they are dropped from the
      numerator only; the denominator stays fixed, so an "invalid" label has
      the same effect as "supported". If all labels are invalid -> return None.
    """
    gen_segments = _split_into_sentences(generated_text)
    if not gen_segments or not input_text or not input_text.strip():
        return 0.0

    input_sentences = _split_into_sentences(input_text)
    # gen_segments is non-empty here, so the denominator is always >= 1.
    stable_denom = max(len(gen_segments), len(input_sentences))

    labels = _call_support_api(
        context=input_text,
        subclaims=gen_segments,
        threshold=threshold,
        batch_size=batch_size,
    )

    # Anything other than a sequence of labels is treated as a total failure.
    if labels is None or not isinstance(labels, (list, tuple)):
        print("Warning: compute_hallucination_score_vs_input received None from API — returning None.")
        return None

    normalized = [str(lbl).strip().lower() for lbl in labels]
    valid_labels = [lbl for lbl in normalized if lbl != "invalid"]
    if not valid_labels:
        print("Warning: all labels were 'invalid' in compute_hallucination_score_vs_input — returning None.")
        return None

    # Every adjudicated label other than "supported" counts as hallucinated.
    hallucinated = sum(1 for lbl in valid_labels if lbl != "supported")
    return hallucinated / stable_denom
|
|
|
|
| |
| |
| |
|
|
| |
# Fallback endpoint for the literacy classifier's language model; the
# VLLM_API_BASE environment variable takes precedence when set.
DEFAULT_API_BASE = "http://172.16.34.19:8040/v1"

# The LM handle is only built when dspy imported successfully; otherwise None.
LITERACY_LM = (
    dspy.LM(
        model="openai/dspy",
        api_base=os.getenv("VLLM_API_BASE", DEFAULT_API_BASE),
        api_key="EMPTY",
        temperature=0.0,
        cache=False,
        timeout=300,
        max_tokens=None,
    )
    if dspy is not None
    else None
)

# Location of the compiled classifier program; override with the
# HEALTH_LITERACY_MODEL_PATH environment variable.
MODEL_PATH = os.getenv(
    "HEALTH_LITERACY_MODEL_PATH",
    "/home/mshahidul/readctrl/code/text_classifier/dspy_model/vllm-Meta-Llama-3.1-8B-Instruct_teacher-gpt5_v1/model.json",
)
|
|
# The classifier classes subclass dspy types, so they are only defined when
# dspy imported successfully.
if dspy is not None:
    class HealthLiteracySignature(dspy.Signature):
        """
        Analyze the linguistic complexity, use of medical jargon, and sentence
        structure of 'generated_text' to determine the health literacy level.
        """

        # NOTE(review): the docstring above and the `desc` strings below are
        # presumably consumed by DSPy as prompt text — treat them as runtime
        # strings and confirm before rewording.
        generated_text = dspy.InputField(
            desc="A version of the source text rewritten for a specific audience."
        )
        literacy_label = dspy.OutputField(
            desc=(
                "Classification: low_health_literacy (simple words, no jargon), "
                "intermediate_health_literacy (moderate technicality), or "
                "proficient_health_literacy (highly technical/original level)."
            )
        )

    class HealthLiteracyClassifier(dspy.Module):
        # Wraps a single ChainOfThought predictor built from
        # HealthLiteracySignature.
        def __init__(self):
            super().__init__()
            self.classifier = dspy.ChainOfThought(HealthLiteracySignature)

        # Forwards `generated_text` to the predictor and returns its result
        # unchanged.
        def forward(self, generated_text):
            return self.classifier(generated_text=generated_text)
|
|
|
|
# Module-level cache for the loaded classifier; filled on first use by
# _get_classifier().
_COMPILED_CLASSIFIER = None
# Set once the classifier-unavailable warning has been printed, so that
# _predict_label() does not repeat it on every call.
_CLASSIFIER_ERROR_LOGGED = False
|
|
|
|
def _load_compiled_classifier(path):
    """
    Load the compiled literacy classifier stored at `path`.

    Tries `dspy.load(path)` first when that function exists. If it is missing
    or raises, falls back to building a fresh HealthLiteracyClassifier and
    calling its `.load(path)`.

    Raises
    ------
    RuntimeError : if dspy is not installed, or if the fallback load fails
                   (the underlying exception is chained as the cause).
    """
    if dspy is None:
        raise RuntimeError("dspy is not installed")

    if hasattr(dspy, "load"):
        try:
            loaded_program = dspy.load(path)
        except Exception:
            # Deliberate best-effort: any failure here just means the
            # state-only loader below is tried instead.
            pass
        else:
            return loaded_program

    fallback = HealthLiteracyClassifier()
    try:
        fallback.load(path)
    except Exception as exc:
        raise RuntimeError(f"Failed to load compiled model from {path}") from exc
    return fallback
|
|
|
|
def _get_classifier():
    """
    Return the cached classifier, loading it from MODEL_PATH on first use.

    Raises
    ------
    FileNotFoundError : if MODEL_PATH does not exist when a load is needed.
    """
    global _COMPILED_CLASSIFIER
    # Fast path: already loaded by an earlier call.
    if _COMPILED_CLASSIFIER is not None:
        return _COMPILED_CLASSIFIER

    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")

    _COMPILED_CLASSIFIER = _load_compiled_classifier(MODEL_PATH)
    return _COMPILED_CLASSIFIER
|
|
|
|
def _parse_solution_json(solution_str):
    """
    Parse a model response into a JSON value.

    Accepts an already-parsed dict/list (returned as-is), a bare JSON string,
    or JSON wrapped in a Markdown code fence. A fence tagged ``json`` (any
    casing) is preferred; otherwise the first plain fence is used. A missing
    closing fence is tolerated.

    Returns
    -------
    The decoded JSON value, or None when the text cannot be parsed.
    """
    if isinstance(solution_str, (dict, list)):
        return solution_str

    try:
        text = str(solution_str).strip()

        # Prefer a fence tagged as JSON, matched case-insensitively.
        fenced = re.search(
            r"```json\s*(.*?)(?:```|\Z)", text, flags=re.DOTALL | re.IGNORECASE
        )
        if fenced is None:
            fenced = re.search(r"```(.*?)(?:```|\Z)", text, flags=re.DOTALL)
        if fenced is not None:
            text = fenced.group(1).strip()

        return json.loads(text)
    except (TypeError, ValueError, RecursionError):
        # ValueError covers json.JSONDecodeError; RecursionError guards against
        # pathologically nested model output.
        return None
|
|
|
|
def _predict_label(generated_text):
    """
    Predict the health-literacy label for `generated_text`.

    Returns
    -------
    str : one of "low_health_literacy", "intermediate_health_literacy",
          "proficient_health_literacy" when the classifier output contains a
          recognisable label; the stripped, lower-cased raw label when it is
          present but unrecognised; "" when no signal is available (dspy
          missing, classifier failure, or no label in the output).
    """
    global _CLASSIFIER_ERROR_LOGGED
    if dspy is None:
        # Log once instead of on every call to keep training logs readable.
        if not _CLASSIFIER_ERROR_LOGGED:
            print("dspy is None")
            _CLASSIFIER_ERROR_LOGGED = True
        return ""

    try:
        classifier = _get_classifier()
        if LITERACY_LM is not None:
            with dspy.context(lm=LITERACY_LM):
                prediction = classifier(generated_text=generated_text)
        else:
            prediction = classifier(generated_text=generated_text)
    except Exception as exc:
        # Best-effort boundary: the classifier is an optional reward signal,
        # so any failure downgrades to "no signal" rather than aborting.
        if not _CLASSIFIER_ERROR_LOGGED:
            print(f"Warning: literacy classifier unavailable, continuing without it: {exc}")
            _CLASSIFIER_ERROR_LOGGED = True
        return ""

    # (marker substring, canonical label), checked in this order.
    known_labels = (
        ("low_health", "low_health_literacy"),
        ("intermediate_health", "intermediate_health_literacy"),
        ("proficient_health", "proficient_health_literacy"),
    )

    raw_label = getattr(prediction, "literacy_label", None) if prediction else None
    # A missing/None label falls back to scanning the whole prediction text,
    # instead of being turned into the bogus label "none".
    searchable = str(prediction) if raw_label is None else str(raw_label)
    searchable = searchable.strip().lower()

    for marker, canonical in known_labels:
        if marker in searchable:
            return canonical

    # An unrecognised but present label is returned as-is (it will simply not
    # match any target); an absent label means no signal.
    return "" if raw_label is None else searchable
|
|
|
|
def _compute_classifier_reward(target_level, gen_text):
    """
    Soft style score in [0, 1] derived from the literacy classifier.

        1.0 -> predicted label equals `target_level` (ignoring case/whitespace)
        0.0 -> predicted label differs
        0.5 -> classifier produced no label; neutral / no signal

    A bounded soft score (rather than +1/-1) keeps the classifier from
    dominating the combined reward.
    """
    predicted = _predict_label(gen_text)
    # Empty string is the "classifier unavailable" marker.
    if predicted == "":
        return 0.5

    is_match = predicted.strip().lower() == target_level.strip().lower()
    return 1.0 if is_match else 0.0
|
|
|
|
| |
| |
| |
|
|
def compute_score(data_source, solution_str, ground_truth, extra_info=None):
    """
    Reward = W_COMPLETENESS * completeness_reward
           + W_CLASSIFIER   * classifier_score
           - hallucination_penalty

    Weights
    -------
    W_COMPLETENESS = 0.7   (dominant: factual coverage of summary)
    W_CLASSIFIER   = 0.3   (style bonus, not a cliff)

    completeness_reward   in [0, 1] — recall: fraction of summary sentences
                                      covered by gen_text (vs summary_text).
    classifier_score      in [0, 1] — 1.0=correct style, 0.0=wrong, 0.5=unavailable.
    hallucination_penalty in [0, 1] — fraction of gen sentences NOT in input_text;
                                      applied only when it exceeds 0.1.

    Parameters
    ----------
    data_source : unused by this function.
    solution_str : model output; JSON (optionally fenced) mapping level -> text.
    ground_truth : dict read for the keys "summary_text" and "input_text".
    extra_info : dict read for the key "target_level".

    Early exits
    -----------
    -1.0 : solution cannot be parsed, is not a JSON object, or the text for
           the target level is missing, not a string, or shorter than 10 chars.
     0.0 : no target level supplied.

    API-failure fallback
    --------------------
    A factual component whose API call failed (None) is omitted. If both fail,
    only the classifier contributes. This prevents a transient server blip
    from injecting a large spurious penalty and destabilising PPO/GRPO.

    Range: [-1, 1] (negative only via hallucination penalty or early exit).
    """
    W_COMPLETENESS = 0.7
    W_CLASSIFIER = 0.3
    # Hallucination scores at or below this level are not penalised.
    HALLUCINATION_TOLERANCE = 0.1

    data = _parse_solution_json(solution_str)
    if not data:
        return -1.0

    target_level = extra_info.get("target_level") if isinstance(extra_info, dict) else None
    if not target_level:
        return 0.0

    # The parser can legitimately return a list; only an object can be indexed
    # by level name.
    if not isinstance(data, dict):
        return -1.0

    gen_text = data.get(target_level, "")
    if not isinstance(gen_text, str) or len(gen_text.strip()) < 10:
        return -1.0

    truth = ground_truth if isinstance(ground_truth, dict) else {}
    summary_text = truth.get("summary_text", "")
    input_text = truth.get("input_text", "")

    # --- completeness (recall against the summary) ---
    completeness_reward = None
    if isinstance(summary_text, str) and summary_text.strip():
        completeness_reward = compute_completeness_reward(
            summary_text=summary_text,
            generated_text=gen_text,
            threshold=0.5,
            batch_size=128,
        )
        # Only warn when a call was actually attempted.
        if completeness_reward is None:
            print("Warning: completeness_reward is None (API failure) — omitting from reward.")

    # --- style ---
    classifier_score = _compute_classifier_reward(target_level, gen_text)

    # --- hallucination (precision against the input) ---
    hallucination_penalty = 0.0
    if isinstance(input_text, str) and input_text.strip():
        hallucination_score = compute_hallucination_score_vs_input(
            input_text=input_text,
            generated_text=gen_text,
            threshold=0.5,
            batch_size=128,
        )
        if hallucination_score is None:
            print("Warning: hallucination_score is None (API failure) — omitting penalty.")
        elif hallucination_score > HALLUCINATION_TOLERANCE:
            hallucination_penalty = hallucination_score

    # --- combine ---
    base_reward = W_CLASSIFIER * classifier_score
    if completeness_reward is not None:
        base_reward += W_COMPLETENESS * completeness_reward

    return base_reward - hallucination_penalty
|
|
|
|
| |
| |
| |
|
|
# Enables the manual smoke test below when this file is run as a script.
test_mode = True


def run_actual_api_test():
    """
    Manual smoke test: score one hand-written example with compute_score()
    and print the result.

    Needs the live support API (and optionally the literacy classifier).
    compute_score() absorbs API failures internally, so the success banner
    only means that no exception escaped, not that every service responded.
    """
    ground_truth = {
        "summary_text": (
            "Lisinopril is used to treat high blood pressure. "
            "It is an ACE inhibitor that helps your heart work better. "
            "Common side effects include a dry cough. "
            "Do not use if you are pregnant."
        ),
        "fulltext_subclaims": [
            "Lisinopril is used to treat high blood pressure.",
            "It belongs to a class of drugs called ACE inhibitors.",
            "Common side effects include a dry cough.",
            "It helps prevent heart attacks and strokes.",
            "Patients should have their kidney function monitored.",
            "Do not use if you are pregnant.",
        ],
        "input_text": (
            "Lisinopril is used to treat high blood pressure. "
            "It is a type of drug called an ACE inhibitor. "
            "It helps your heart work better."
        ),
    }

    generated_response = {
        "low_health_literacy": (
            "This medicine is for your high blood pressure. "
            "It is a type of drug called an ACE inhibitor. "
            "It helps your heart work better. "
            "Do not take it if you are pregnant."
        )
    }

    # Mimic a model reply that wraps its JSON in a Markdown code fence.
    solution_str = f"```json\n{json.dumps(generated_response)}\n```"
    extra_info = {"target_level": "low_health_literacy"}

    print("📡 Running summary-text hallucination check test...")
    start_time = time.time()

    try:
        score = compute_score(
            data_source="real_api_test",
            solution_str=solution_str,
            ground_truth=ground_truth,
            extra_info=extra_info,
        )

        duration = time.time() - start_time
        print(f"\n✅ API Call Successful ({round(duration, 2)}s)")
        print("-" * 40)
        print(f"Target Level : {extra_info['target_level']}")
        print(f"Final Reward : {round(score, 4)}")
        print("-" * 40)
        print("\nDEBUG INFO:")
        print("- completeness_reward   : fraction of summary_text sentences covered by the generated text.")
        print("- classifier_score      : 1.0 if literacy label matches target, 0.0 otherwise, 0.5 if unavailable.")
        print("- hallucination_penalty : fraction of gen sentences NOT in input_text (subtracted when > 0.1).")
        print("- Final = 0.7 * completeness_reward + 0.3 * classifier_score - hallucination_penalty")

    except Exception as e:
        # Top-level boundary of a manual script: report and exit normally.
        print("\n❌ API Call Failed!")
        print(f"Error Type: {type(e).__name__}")
        print(f"Details: {str(e)}")
        print("\nPossible fixes:")
        print("1. Check if the vLLM server at :8090 is running.")
        print("2. Verify SUPPORT_API_BASE env var is set correctly.")


if __name__ == "__main__" and test_mode:
    run_actual_api_test()