import json
import re
import concurrent.futures

from openai import OpenAI


class MedicalClaimVerifier:
    """Score generated health-literacy summaries for an RL reward.

    Uses an LLM as a claim-entailment judge: each subclaim is checked for
    support in the generated text. Per literacy level, a completeness score
    (vs. gold-summary subclaims) and a coverage score (vs. full-text
    subclaims) are thresholded into +1/-1 reward components.
    """

    def __init__(self, api_file="/home/mshahidul/api_new.json"):
        # OpenAI API configuration. `api_file` is a JSON file containing an
        # "openai" key; parameterized (with the original path as default) so
        # the verifier is usable outside the author's machine.
        with open(api_file, "r") as f:
            api_keys = json.load(f)
        self.api_key = api_keys["openai"]
        self.model_name = "gpt-5-mini"
        self.client = OpenAI(api_key=self.api_key)

        # Literacy ranges (IQR after outlier removal) from paper summary
        # comp = completeness vs gold summary; cov = source_coverage vs full text
        self.threshold_ranges = {
            "low": {"comp": (0.9600, 1.0000), "cov": (0.1765, 0.3226)},
            "intermediate": {"comp": (0.9393, 1.0000), "cov": (0.1818, 0.4091)},
            "proficient": {"comp": (0.9231, 1.0000), "cov": (0.7725, 0.9347)},
        }
        # Minimum required information (upper bound of IQR)
        self.thresholds = {
            "low": {"comp": 1.0, "cov": 0.3226},
            "intermediate": {"comp": 1.0, "cov": 0.4091},
            "proficient": {"comp": 1.0, "cov": 0.9347},
        }

    def get_prompt(self, context, claim):
        """Build the entailment prompt asking whether `context` supports `claim`."""
        prompt = f"""
CONTEXT:
{context}

CLAIM TO VERIFY:
{claim}

INSTRUCTION:
Does the CONTEXT above provide enough evidence to support the CLAIM?
- Answer 'supported' if the claim is explicitly stated or logically followable.
- Answer 'not_supported' if the claim is missing, contradicts the text, or requires outside info.
Output only one word: 'supported' or 'not_supported'.
"""
        return prompt

    def check_support_api(self, prompt):
        """Ask the judge model whether the claim is supported.

        Returns 1.0 for a supported verdict, 0.0 otherwise.

        Fix: the original test `"supported" in res and "not_supported" not in
        res` misclassified negations written without the underscore (e.g.
        "not supported") as supported; any negated form now counts as 0.0.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
            )
            res = response.choices[0].message.content.strip().lower()
            # print("API Response:", res)
            if re.search(r"\bnot[\s_-]*supported\b", res):
                return 0.0
            return 1.0 if "supported" in res else 0.0
        except Exception as e:
            # Best-effort: API failures count as "not supported" rather than
            # crashing the RL training loop.
            print(f"API call error: {e}")
            return 0.0

    def evaluate_level(self, gen_text, gold_subs, full_subs, level_key):
        """Calculate (completeness, coverage) scores for one literacy level.

        `level_key` is currently only informational (kept for logging and
        interface stability); scoring does not depend on it.
        Returns (0.0, 0.0) for empty text or on parallel-execution failure.
        """
        if not gen_text:
            return 0.0, 0.0

        # Run API calls in parallel to save time during RL.
        try:
            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                # Completeness check (vs Gold Summary Subclaims)
                comp_prompts = [self.get_prompt(gen_text, s) for s in gold_subs]
                comp_results = list(executor.map(self.check_support_api, comp_prompts))
                comp_score = sum(comp_results) / len(comp_results) if comp_results else 0.0

                # Coverage check (vs Full Text Subclaims)
                cov_prompts = [self.get_prompt(gen_text, s) for s in full_subs]
                cov_results = list(executor.map(self.check_support_api, cov_prompts))
                cov_score = sum(cov_results) / len(cov_results) if cov_results else 0.0
                # print(f"Comp Score: {comp_score}, Cov Score: {cov_score} for {level_key}")
        except Exception as e:
            print(f"Parallel API call error: {e}")
            return 0.0, 0.0

        return comp_score, cov_score

    def get_reward_score(self, completion, gold_subs, full_subs):
        """Compute the total RL reward for one model completion.

        `completion` is a chat-style list whose first element has a 'content'
        string carrying (possibly fenced) JSON with one
        '<level>_health_literacy' key per literacy level.

        Returns:
            -5.0 if the JSON cannot be extracted/parsed (or scoring crashes),
            -2.0 if required level keys are missing,
            otherwise the sum of +1.0 / -1.0 per (completeness, coverage)
            threshold check per level.
        """
        # 1. Robust JSON extraction: strip markdown fences and whitespace.
        try:
            text = (
                completion[0]["content"]
                .strip()
                .replace("```json", "")
                .replace("```", "")
                .strip()
            )
            data = json.loads(text)
        except (IndexError, KeyError, TypeError, ValueError):
            # json.JSONDecodeError is a subclass of ValueError; KeyError and
            # TypeError cover malformed `completion` containers.
            print("JSON Parsing Error in Reward Calculation")
            return -5.0

        # json.loads may legally return a non-dict (list, string, number);
        # treat that as a parse failure rather than crashing the schema check.
        if not isinstance(data, dict):
            print("JSON Parsing Error in Reward Calculation")
            return -5.0

        # 2. Schema Validation
        levels = ["low", "intermediate", "proficient"]
        if not all(f"{lvl}_health_literacy" in data for lvl in levels):
            return -2.0  # Slightly smaller penalty for partial formatting success

        # 3. Scoring Logic
        try:
            total_reward = 0.0
            pass_reward = 1.0
            fail_penalty = -1.0

            for lvl in levels:
                gen_text = data.get(f"{lvl}_health_literacy", "")

                # Skip scoring if text is empty (single -1; a populated but
                # failing level can lose up to -2 across both checks).
                if not gen_text:
                    total_reward += fail_penalty
                    continue

                comp_score, cov_score = self.evaluate_level(
                    gen_text, gold_subs, full_subs, lvl
                )

                # Apply thresholds: completeness must hit the (strict) comp
                # floor, coverage the level-specific cov floor.
                total_reward += (
                    pass_reward
                    if comp_score >= self.thresholds[lvl]["comp"]
                    else fail_penalty
                )
                total_reward += (
                    pass_reward
                    if cov_score >= self.thresholds[lvl]["cov"]
                    else fail_penalty
                )

            # Return AFTER every level has been scored (returning inside the
            # loop would silently drop the remaining levels).
            return total_reward
        except Exception:
            return -5.0


# 1. Ground Truth Subclaims (Extracted from a medical paper on Hypertension)
gold_summary_subclaims = [
    "Hypertension is defined as blood pressure above 140/90 mmHg.",
    "Lifestyle changes like low salt intake can reduce blood pressure.",
    "Diuretics are often the first line of pharmacological treatment.",
]

full_text_subclaims = [
    "Hypertension is defined as blood pressure above 140/90 mmHg.",
    "Lifestyle changes like low salt intake can reduce blood pressure.",
    "Diuretics are often the first line of pharmacological treatment.",
    "The DASH diet emphasizes fruits, vegetables, and low-fat dairy.",
    "Chronic hypertension increases the risk of stroke and myocardial infarction.",
    "ACE inhibitors are contraindicated during pregnancy.",
    "Secondary hypertension can be caused by renal artery stenosis.",
]

# 2. Mock Model Completion (The output being evaluated)
# This mimics the format your RL environment would pass to the reward function.
# NOTE: each JSON string value must stay on one line — a raw newline inside a
# JSON string is invalid and would make json.loads fail (reward -5.0).
mock_completion = [{
    'content': """
{
  "low_health_literacy": "High blood pressure is when your blood is too strong for your veins. You should eat less salt to help stay healthy.",
  "intermediate_health_literacy": "Hypertension is blood pressure over 140/90. You can lower it by eating less salt and taking water pills (diuretics) if your doctor says so.",
  "proficient_health_literacy": "Hypertension (BP > 140/90 mmHg) is managed via lifestyle modifications like the DASH diet and salt restriction. Pharmacological interventions include diuretics as first-line therapy, though risks like stroke or heart attack persist if untreated. Secondary causes like renal artery stenosis should be screened, and ACE inhibitors must be avoided in pregnancy."
}
"""
}]


if __name__ == "__main__":
    # Guarded so that importing this module does not read the key file or
    # fire live API calls.
    verifier = MedicalClaimVerifier()

    # Test the reward calculation
    reward = verifier.get_reward_score(
        completion=mock_completion,
        gold_subs=gold_summary_subclaims,
        full_subs=full_text_subclaims,
    )

    print("--- Evaluation Result ---")
    print(f"Total Reward Score: {reward}")

    # Logic Explanation:
    # - Low: Likely fails 'comp' (missing 140/90 info), but might pass 'cov' (low threshold).
    # - Intermediate: Likely passes 'comp' and 'cov'.
    # - Proficient: Needs to cover almost all 7 subclaims to pass the 0.77 coverage threshold.