import os
import re
import json
import argparse
from typing import Any, List, Dict, Optional
import warnings
import time
import requests

warnings.filterwarnings("ignore")

# NOTE(review): the self-test section at the bottom of this file re-assigns
# test_mode = True; this module-level default is effectively unused.
test_mode = False

try:
    import dspy
except ImportError:
    dspy = None

# Base URL of the FastAPI support-checking service (overridable via env var).
SUPPORT_API_BASE = os.getenv("SUPPORT_API_BASE", "http://172.16.34.19:8090")


# ---------------------------------------------------------------------------
# Support-API helper
# ---------------------------------------------------------------------------
def _call_support_api(
    context: str,
    subclaims: List[str],
    threshold: float = 0.5,
    batch_size: int = 128,
    max_retries: int = 3,
    initial_retry_delay: float = 5.0,
    backoff_factor: float = 2.0,
) -> Optional[List[str]]:
    """
    Call the FastAPI /check_support endpoint.

    Parameters
    ----------
    context : premise text the subclaims are adjudicated against.
    subclaims : list of claims to check for support.
    threshold, batch_size : forwarded verbatim to the API payload.
    max_retries, initial_retry_delay, backoff_factor : exponential-backoff
        retry policy for transport-level failures.

    Returns
    -------
    List[str] : one label per subclaim — "supported" | "not_supported" | "invalid".
    None : returned on a TOTAL network/transport failure, so callers can
        distinguish a genuine API error from a valid "not_supported" label
        and avoid applying a false penalty.
    """
    if not context or not subclaims:
        return ["invalid"] * len(subclaims)

    api_url = f"{SUPPORT_API_BASE}/check_support"
    payload = {
        "context": context,
        "subclaims": subclaims,
        "threshold": threshold,
        "batch_size": batch_size,
    }

    attempt = 0
    # We treat *any* RequestException (including HTTP 5xx raised via
    # raise_for_status) as retryable up to max_retries. After exhausting
    # retries, we return None so callers can skip applying penalties.
    while True:
        try:
            response = requests.post(api_url, json=payload, timeout=300)
            response.raise_for_status()
            result = response.json()
            return result.get("labels", ["invalid"] * len(subclaims))
        except requests.exceptions.RequestException as exc:
            attempt += 1
            if attempt > max_retries:
                print(
                    f"Warning: Support API call failed after {max_retries} retries "
                    f"(returning None): {exc}"
                )
                return None  # None signals total failure; NOT the same as "not_supported"

            # Exponential backoff between retries.
            delay = initial_retry_delay * (backoff_factor ** (attempt - 1))
            print(
                f"Warning: Support API call failed (attempt {attempt}/{max_retries}); "
                f"retrying in {delay:.1f}s: {exc}"
            )
            try:
                time.sleep(delay)
            except Exception:
                # If sleep is interrupted for any reason, break early and
                # surface failure.
                return None


# ---------------------------------------------------------------------------
# Sentence splitter
# ---------------------------------------------------------------------------
# Minimum character length for a sentence to be considered a real unit.
# Fragments shorter than this (e.g. "Yes.", bullet stubs) are discarded
# to prevent models from padding with trivially short safe sentences.
MIN_SENTENCE_CHARS = 15

# Compiled once at import time; splitting runs on every scoring call.
_SENTENCE_BOUNDARY_RE = re.compile(r"(?<=[.!?])\s+")


def _split_into_sentences(text: str, min_chars: int = MIN_SENTENCE_CHARS) -> List[str]:
    """
    Split text into sentences at [.!?] boundaries.

    Segments shorter than `min_chars` characters are dropped to prevent
    micro-fragment padding from gaming ratio-based scores.
    """
    if not text or not text.strip():
        return []
    parts = _SENTENCE_BOUNDARY_RE.split(text.strip())
    return [s.strip() for s in parts if len(s.strip()) >= min_chars]


# ---------------------------------------------------------------------------
# Completeness reward (Recall direction: summary_text → generated_text)
# ---------------------------------------------------------------------------
# True completeness = how much of the reference (summary_text) is covered
# by the generated text. This is the RECALL direction:
#
#   For each sentence in summary_text:
#       Is it supported/entailed by generated_text?
#   completeness = covered_summary_sentences / total_summary_sentences
#
# This prevents reward hacking: generating a single safe sentence will no
# longer score 100%; the model must cover more of the summary to score high.
# ---------------------------------------------------------------------------
# Local import so this section's Optional[...] hints are self-contained.
from typing import Optional


def compute_incompleteness_score(
    summary_text: str,
    generated_text: str,
    threshold: float = 0.5,
    batch_size: int = 32,  # NOTE(review): smaller default than the other scorers (128) — confirm intended
) -> Optional[float]:
    """
    Incompleteness score in [0, 1]: fraction of summary_text sentences NOT
    covered by generated_text. Returns None on API failure.

    Direction: summary_text sentences are the 'subclaims'; generated_text is
    the 'context' (premise). This is the recall direction.

    API-failure handling
    --------------------
    - Total failure (_call_support_api returns None) → return None.
      The caller treats None as a null signal (no completeness component),
      preventing a spurious zero-completeness penalty from destabilising RL.
    - Partial failure (some labels are "invalid") → those labels are filtered
      out; only genuinely adjudicated labels contribute to the score.
      If ALL labels are invalid, returns None (treated as total failure).
    """
    summary_sentences = _split_into_sentences(summary_text)
    if not summary_sentences:
        return 0.0
    if not generated_text or not generated_text.strip():
        return 1.0  # Nothing generated → fully incomplete

    labels = _call_support_api(
        context=generated_text,
        subclaims=summary_sentences,
        threshold=threshold,
        batch_size=batch_size,
    )

    # Total API failure
    if labels is None:
        print("Warning: compute_incompleteness_score received None from API — returning None.")
        return None

    # Partial failure: filter out "invalid" labels; score only valid ones
    valid_labels = [lbl for lbl in labels if str(lbl).strip().lower() != "invalid"]
    if not valid_labels:
        print("Warning: all labels were 'invalid' in compute_incompleteness_score — returning None.")
        return None

    not_covered = sum(
        1 for lbl in valid_labels if str(lbl).strip().lower() != "supported"
    )
    return not_covered / len(valid_labels)


def compute_completeness_reward(
    summary_text: str,
    generated_text: str,
    threshold: float = 0.5,
    batch_size: int = 128,
) -> Optional[float]:
    """
    Completeness reward in [0, 1]: fraction of summary_text sentences that
    ARE covered by generated_text (i.e. 1 – incompleteness_score).
    Returns None if the API failed (propagated from
    compute_incompleteness_score).

    This is the RECALL direction:
        completeness_reward = covered_summary_sentences / total_summary_sentences

    A model that generates only one sentence can score at most 1/N (where
    N = number of summary sentences), preventing reward hacking.
    """
    incompleteness_score = compute_incompleteness_score(
        summary_text=summary_text,
        generated_text=generated_text,
        threshold=threshold,
        batch_size=batch_size,
    )
    if incompleteness_score is None:
        return None  # propagate API-failure signal
    return 1.0 - incompleteness_score


# ---------------------------------------------------------------------------
# Hallucination penalty: gen_text sentences vs. input_text (full source)
# ---------------------------------------------------------------------------
def compute_hallucination_score_vs_input(
    input_text: str,
    generated_text: str,
    threshold: float = 0.5,
    batch_size: int = 128,
) -> Optional[float]:
    """
    Hallucination score in [0, 1]: fraction of generated sentences NOT
    supported by input_text. Returns None on API failure.

    Anti-padding design
    -------------------
    1. Minimum-length filter: segments < MIN_SENTENCE_CHARS chars are discarded.
    2. Fixed denominator: max(n_gen_filtered, n_input_sentences) so padding
       safe sentences cannot dilute the hallucination ratio.

    API-failure handling
    --------------------
    - Total failure (None from API) → return None. The caller omits the
      hallucination penalty rather than applying a massive spurious penalty
      from a transient server blip.
    - Partial failure (some "invalid" labels) → filter them out; score only
      the valid labels. If all labels invalid → return None.
    """
    gen_segments = _split_into_sentences(generated_text)
    if not gen_segments or not input_text or not input_text.strip():
        return 0.0

    input_sentences = _split_into_sentences(input_text)
    stable_denom = max(len(gen_segments), len(input_sentences))
    if stable_denom == 0:
        return 0.0

    labels = _call_support_api(
        context=input_text,
        subclaims=gen_segments,
        threshold=threshold,
        batch_size=batch_size,
    )

    # Total API failure
    if labels is None:
        print("Warning: compute_hallucination_score_vs_input received None from API — returning None.")
        return None

    # Partial failure: filter "invalid" labels
    valid_labels = [lbl for lbl in labels if str(lbl).strip().lower() != "invalid"]
    if not valid_labels:
        print("Warning: all labels were 'invalid' in compute_hallucination_score_vs_input — returning None.")
        return None

    hallucinated = sum(
        1 for lbl in valid_labels if str(lbl).strip().lower() != "supported"
    )
    # Use stable_denom to block padding inflation (not len(valid_labels))
    return hallucinated / stable_denom


# ---------------------------------------------------------------------------
# DSPy health-literacy classifier
# ---------------------------------------------------------------------------
# DEFAULT_API_BASE = "http://172.16.34.22:8040/v1"
DEFAULT_API_BASE = "http://172.16.34.19:8040/v1"

if dspy is not None:
    LITERACY_LM = dspy.LM(
        model="openai/dspy",
        api_base=os.getenv("VLLM_API_BASE", DEFAULT_API_BASE),
        api_key="EMPTY",
        temperature=0.0,
        cache=False,
        timeout=300,
        max_tokens=None,
    )
else:
    LITERACY_LM = None

MODEL_PATH = os.environ.get(
    "HEALTH_LITERACY_MODEL_PATH",
    "/home/mshahidul/readctrl/code/text_classifier/"
    "dspy_model/vllm-Meta-Llama-3.1-8B-Instruct_teacher-gpt5_v1/model.json",
)

if dspy is not None:

    class HealthLiteracySignature(dspy.Signature):
        """
        Analyze the linguistic complexity, use of medical jargon, and sentence
        structure of 'generated_text' to determine the health literacy level.
        """
        generated_text = dspy.InputField(
            desc="A version of the source text rewritten for a specific audience."
        )
        literacy_label = dspy.OutputField(
            desc=(
                "Classification: low_health_literacy (simple words, no jargon), "
                "intermediate_health_literacy (moderate technicality), or "
                "proficient_health_literacy (highly technical/original level)."
            )
        )

    class HealthLiteracyClassifier(dspy.Module):
        def __init__(self):
            super().__init__()
            self.classifier = dspy.ChainOfThought(HealthLiteracySignature)

        def forward(self, generated_text):
            return self.classifier(generated_text=generated_text)

# Lazily-populated cache for the compiled classifier, and a one-shot flag so
# repeated classifier failures are only logged once per process.
_COMPILED_CLASSIFIER = None
_CLASSIFIER_ERROR_LOGGED = False


def _load_compiled_classifier(path):
    """Load a compiled DSPy program from `path`.

    Tries the modern `dspy.load` entry point first, then falls back to
    loading saved state into a fresh HealthLiteracyClassifier instance.
    Raises RuntimeError if dspy is missing or both load paths fail.
    """
    if dspy is None:
        raise RuntimeError("dspy is not installed")
    if hasattr(dspy, "load"):
        try:
            return dspy.load(path)
        except Exception:
            pass  # fall through to state-dict loading below
    classifier = HealthLiteracyClassifier()
    try:
        classifier.load(path)
    except Exception as exc:
        raise RuntimeError(f"Failed to load compiled model from {path}") from exc
    return classifier


def _get_classifier():
    """Return the compiled classifier, loading it on first use."""
    global _COMPILED_CLASSIFIER
    if _COMPILED_CLASSIFIER is None:
        if not os.path.exists(MODEL_PATH):
            raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")
        _COMPILED_CLASSIFIER = _load_compiled_classifier(MODEL_PATH)
    return _COMPILED_CLASSIFIER


def _parse_solution_json(solution_str):
    """Parse a model response into JSON, stripping ``` fences.

    Accepts already-parsed dict/list inputs unchanged; returns None when
    the payload cannot be parsed.
    """
    if isinstance(solution_str, (dict, list)):
        return solution_str
    try:
        cleaned_str = str(solution_str).strip()
        if "```json" in cleaned_str:
            cleaned_str = cleaned_str.split("```json")[1].split("```")[0].strip()
        elif "```" in cleaned_str:
            cleaned_str = cleaned_str.split("```")[1].split("```")[0].strip()
        return json.loads(cleaned_str)
    except Exception:
        return None


def _predict_label(generated_text):
    """Run the literacy classifier on `generated_text`.

    Returns the normalized (lower-cased, stripped) literacy label, or ""
    when dspy/the classifier is unavailable or prediction fails.
    """
    global _CLASSIFIER_ERROR_LOGGED
    if dspy is None:
        print("dspy is None")
        return ""
    try:
        classifier = _get_classifier()
        if LITERACY_LM is not None:
            with dspy.context(lm=LITERACY_LM):
                prediction = classifier(generated_text=generated_text)
        else:
            prediction = classifier(generated_text=generated_text)
    except Exception as exc:
        if not _CLASSIFIER_ERROR_LOGGED:
            print(f"Warning: literacy classifier unavailable, continuing without it: {exc}")
            _CLASSIFIER_ERROR_LOGGED = True
        return ""

    if not prediction or not hasattr(prediction, "literacy_label"):
        # Fallback: substring-match the raw prediction text for a known label.
        prd = str(prediction)
        if "low_health" in prd:
            return "low_health_literacy"
        elif "intermediate_health" in prd:
            return "intermediate_health_literacy"
        elif "proficient_health" in prd:
            return "proficient_health_literacy"
        return ""
    return str(prediction.literacy_label).strip().lower()


def _compute_classifier_reward(target_level, gen_text):
    """
    Soft classifier score in [0, 1] (NOT binary +1/-1).

        1.0 — predicted label matches target level (correct style)
        0.0 — predicted label does not match (wrong style)
        0.5 — classifier unavailable; neutral / no signal

    Using a soft score instead of ±1 prevents the classifier from dominating
    and creating a reward cliff.
    """
    result = _predict_label(gen_text)
    if result == "":
        # unavailable → neutral, no penalty
        return 0.5
    if result.strip().lower() == target_level.strip().lower():
        return 1.0  # correct literacy style
    return 0.0  # wrong literacy style (penalty-free cliff avoided)


# ---------------------------------------------------------------------------
# Main scoring function
# ---------------------------------------------------------------------------
def compute_score(data_source, solution_str, ground_truth, extra_info=None):
    """
    Reward = W_COMPLETENESS * completeness_reward
           + W_CLASSIFIER * classifier_score
           - hallucination_penalty

    Weights
    -------
    W_COMPLETENESS = 0.7 (dominant: factual coverage of summary)
    W_CLASSIFIER   = 0.3 (style bonus, not a cliff)

    completeness_reward ∈ [0, 1] — recall: fraction of summary sentences
        covered by gen_text (vs summary_text).
    classifier_score ∈ [0, 1] — 1.0=correct style, 0.0=wrong, 0.5=unavailable.
    hallucination_penalty ∈ [0, 1] — fraction of gen sentences NOT in input_text.

    API-failure fallback
    --------------------
    If both factual API calls fail (completeness=None, hallucination=None),
    only the classifier contributes. This prevents a transient server blip
    from injecting a large spurious penalty and destabilising PPO/GRPO.

    Range: [-1, 1] (negative only via hallucination penalty).
    """
    W_COMPLETENESS = 0.7
    W_CLASSIFIER = 0.3

    # 1. Format & Data Validation
    data = _parse_solution_json(solution_str)
    # A non-dict parse result (e.g. a top-level JSON list) has no per-level
    # text to look up and would crash .get() below — treat it as a format
    # failure just like unparseable output.
    if not data or not isinstance(data, dict):
        return -1.0

    target_level = extra_info.get("target_level") if extra_info else None
    if not target_level:
        return 0.0

    gen_text = data.get(target_level, "")
    # Non-string JSON values (numbers, nested objects) would crash .strip();
    # reject them as malformed output.
    if not isinstance(gen_text, str) or not gen_text or len(gen_text.strip()) < 10:
        return -1.0

    summary_text = ground_truth.get("summary_text", "")
    input_text = ground_truth.get("input_text", "")

    # 2. Completeness reward (recall: summary_text → gen_text)
    completeness_reward = None
    if summary_text and summary_text.strip():
        completeness_reward = compute_completeness_reward(
            summary_text=summary_text,
            generated_text=gen_text,
            threshold=0.5,
            batch_size=128,
        )
        # None = API failure → log and skip component
        if completeness_reward is None:
            print("Warning: completeness_reward is None (API failure) — omitting from reward.")

    # 3. Classifier score (soft bonus: 1.0 match / 0.0 mismatch / 0.5 unavailable)
    classifier_score = _compute_classifier_reward(target_level, gen_text)

    # 4. Hallucination penalty (gen_text → input_text)
    hallucination_penalty = None
    if input_text and input_text.strip():
        hallucination_score = compute_hallucination_score_vs_input(
            input_text=input_text,
            generated_text=gen_text,
            threshold=0.5,
            batch_size=128,
        )
        if hallucination_score is None:
            print("Warning: hallucination_score is None (API failure) — omitting penalty.")
        elif hallucination_score > 0.1:  # ignore trivial noise
            hallucination_penalty = hallucination_score

    # 5. Final reward — gracefully degrade when API signals are missing
    if completeness_reward is not None:
        base_reward = W_COMPLETENESS * completeness_reward + W_CLASSIFIER * classifier_score
    else:
        # API failed for completeness: use classifier-only signal (small but stable)
        base_reward = W_CLASSIFIER * classifier_score

    penalty = hallucination_penalty if hallucination_penalty is not None else 0.0
    return base_reward - penalty


# ---------------------------------------------------------------------------
# Test mode
# ---------------------------------------------------------------------------
test_mode = True

if test_mode:

    def run_actual_api_test():
        """End-to-end smoke test of compute_score against the live support API."""
        # Prepare real medical data
        ground_truth = {
            "summary_text": (
                "Lisinopril is used to treat high blood pressure. "
                "It is an ACE inhibitor that helps your heart work better. "
                "Common side effects include a dry cough. "
                "Do not use if you are pregnant."
            ),
            "fulltext_subclaims": [
                "Lisinopril is used to treat high blood pressure.",
                "It belongs to a class of drugs called ACE inhibitors.",
                "Common side effects include a dry cough.",
                "It helps prevent heart attacks and strokes.",
                "Patients should have their kidney function monitored.",
                "Do not use if you are pregnant.",
            ],
            "input_text": (
                "Lisinopril is used to treat high blood pressure. "
                "It is a type of drug called an ACE inhibitor. "
                "It helps your heart work better."
            ),
        }

        # LLM output: well-grounded in summary_text
        generated_response = {
            "low_health_literacy": (
                "This medicine is for your high blood pressure. "
                "It is a type of drug called an ACE inhibitor. "
                "It helps your heart work better. "
                "Do not take it if you are pregnant."
            )
        }

        solution_str = f"```json\n{json.dumps(generated_response)}\n```"
        extra_info = {"target_level": "low_health_literacy"}

        print("📡 Running summary-text hallucination check test...")
        start_time = time.time()
        try:
            score = compute_score(
                data_source="real_api_test",
                solution_str=solution_str,
                ground_truth=ground_truth,
                extra_info=extra_info,
            )
            duration = time.time() - start_time
            print(f"\n✅ API Call Successful ({round(duration, 2)}s)")
            print("-" * 40)
            print(f"Target Level : {extra_info['target_level']}")
            print(f"Final Reward : {round(score, 4)}")
            print("-" * 40)
            print("\nDEBUG INFO:")
            # These lines previously described an outdated reward formula
            # (±1 classifier, equal-weight average); they now match compute_score.
            print("- completeness_reward : fraction of summary sentences covered by gen_text (recall).")
            print("- classifier_score : 1.0 if literacy label matches target, 0.0 if not, 0.5 if unavailable.")
            print("- hallucination_penalty : fraction of gen sentences NOT in input_text (subtracted).")
            print("- Final = 0.7 * completeness_reward + 0.3 * classifier_score - hallucination_penalty")
        except Exception as e:
            print(f"\n❌ API Call Failed!")
            print(f"Error Type: {type(e).__name__}")
            print(f"Details: {str(e)}")
            print("\nPossible fixes:")
            print("1. Check if the vLLM server at :8090 is running.")
            print("2. Verify SUPPORT_API_BASE env var is set correctly.")

    if __name__ == "__main__":
        run_actual_api_test()