| | import json |
| | import re |
| | import concurrent.futures |
| | from openai import OpenAI |
| |
|
class MedicalClaimVerifier:
    """Reward scorer that checks sub-claims against generated text via an LLM.

    For each of three health-literacy levels, the generated text is used as
    the CONTEXT and every sub-claim is sent to the chat model, which answers
    'supported' or 'not_supported'. The fraction of supported claims is
    compared with per-level thresholds to produce a reward.
    """

    # Verdict parsing. The negative pattern is checked first because every
    # negative spelling ("not supported", "not-supported", "unsupported")
    # contains the substring "supported".
    _NEGATIVE_VERDICT = re.compile(r"\b(?:not[\s_-]*|un)supported\b")
    _POSITIVE_VERDICT = re.compile(r"\bsupported\b")

    def __init__(self, api_file="/home/mshahidul/api_new.json",
                 model_name="gpt-5-mini"):
        """Load the OpenAI key and set up the client and thresholds.

        Args:
            api_file: Path to a JSON file holding the key under "openai".
            model_name: Chat model used for claim verification.

        Raises:
            OSError: If ``api_file`` cannot be opened.
            KeyError: If the file has no "openai" entry.
        """
        with open(api_file, "r", encoding="utf-8") as f:
            api_keys = json.load(f)
        self.api_key = api_keys["openai"]
        self.model_name = model_name
        self.client = OpenAI(api_key=self.api_key)

        # (lower, upper) score ranges per level. No method of this class
        # reads this table; it is kept as reference for ``self.thresholds``.
        # "comp"/"cov" presumably mean completeness/coverage -- TODO confirm.
        self.threshold_ranges = {
            "low": {"comp": (0.9600, 1.0000), "cov": (0.1765, 0.3226)},
            "intermediate": {"comp": (0.9393, 1.0000), "cov": (0.1818, 0.4091)},
            "proficient": {"comp": (0.9231, 1.0000), "cov": (0.7725, 0.9347)},
        }

        # Minimum scores needed to earn the pass reward; each value equals
        # the upper bound of the matching range above.
        self.thresholds = {
            "low": {"comp": 1.0, "cov": 0.3226},
            "intermediate": {"comp": 1.0, "cov": 0.4091},
            "proficient": {"comp": 1.0, "cov": 0.9347},
        }

    def get_prompt(self, context, claim):
        """Return the verification prompt for one (context, claim) pair."""
        return (
            "\n"
            "CONTEXT:\n"
            f"{context}\n"
            "\n"
            "CLAIM TO VERIFY:\n"
            f"{claim}\n"
            "\n"
            "INSTRUCTION:\n"
            "Does the CONTEXT above provide enough evidence to support the CLAIM?\n"
            "- Answer 'supported' if the claim is explicitly stated or logically followable.\n"
            "- Answer 'not_supported' if the claim is missing, contradicts the text, or requires outside info.\n"
            "\n"
            "Output only one word: 'supported' or 'not_supported'.\n"
        )

    @staticmethod
    def _parse_verdict(content):
        """Map a raw model answer to 1.0 (supported) or 0.0 (anything else).

        Non-string content (e.g. ``None``) and any negated form of
        "supported" yield 0.0.
        """
        if not isinstance(content, str):
            return 0.0
        answer = content.strip().lower()
        if MedicalClaimVerifier._NEGATIVE_VERDICT.search(answer):
            return 0.0
        return 1.0 if MedicalClaimVerifier._POSITIVE_VERDICT.search(answer) else 0.0

    def check_support_api(self, prompt):
        """Ask the chat model to judge one prompt.

        Returns:
            1.0 if the model answers 'supported', otherwise 0.0. Any API
            failure is reported on stdout and also scored as 0.0.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
            )
            content = response.choices[0].message.content
        except Exception as e:
            # Deliberate best-effort: one failed call must not abort scoring.
            print(f"API call error: {e}")
            return 0.0
        return self._parse_verdict(content)

    def evaluate_level(self, gen_text, gold_subs, full_subs, level_key):
        """Calculates scores for a single literacy level.

        Args:
            gen_text: Generated text used as the CONTEXT of every prompt.
            gold_subs: Sub-claims whose supported fraction is the first score.
            full_subs: Sub-claims whose supported fraction is the second score.
            level_key: Literacy level name; not used by the computation.

        Returns:
            ``(comp_score, cov_score)``, each in [0.0, 1.0]. Both are 0.0 when
            ``gen_text`` is empty or the evaluation raises; a score is 0.0
            when its claim list is empty.
        """
        if not gen_text:
            return 0.0, 0.0

        try:
            comp_prompts = [self.get_prompt(gen_text, s) for s in gold_subs]
            cov_prompts = [self.get_prompt(gen_text, s) for s in full_subs]
            # Submit both batches together so the pool stays busy instead of
            # draining the first batch before the second one starts.
            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                results = list(
                    executor.map(self.check_support_api, comp_prompts + cov_prompts)
                )
        except Exception as e:
            print(f"Parallel API call error: {e}")
            return 0.0, 0.0

        # executor.map preserves input order, so the results split cleanly.
        split = len(comp_prompts)
        comp_results = results[:split]
        cov_results = results[split:]
        comp_score = sum(comp_results) / len(comp_results) if comp_results else 0.0
        cov_score = sum(cov_results) / len(cov_results) if cov_results else 0.0
        return comp_score, cov_score

    def get_reward_score(self, completion, gold_subs, full_subs):
        """Compute the total reward for one model completion.

        Args:
            completion: Sequence whose first item is a mapping with a
                'content' string holding JSON (optionally wrapped in
                ``` fences) with the keys '<level>_health_literacy'.
            gold_subs: Sub-claims forwarded to ``evaluate_level``.
            full_subs: Sub-claims forwarded to ``evaluate_level``.

        Returns:
            -5.0 if the completion is malformed or scoring raises; -2.0 if
            the parsed JSON is not an object containing all three level
            keys; otherwise the sum over levels of +1.0 / -1.0 for each of
            the two threshold checks.
        """
        try:
            raw = completion[0]['content']
            text = raw.strip().replace("```json", "").replace("```", "").strip()
            data = json.loads(text)
        except (IndexError, KeyError, TypeError, AttributeError, ValueError):
            # ValueError covers json.JSONDecodeError; the others cover a
            # completion with the wrong shape or non-string content.
            print("JSON Parsing Error in Reward Calculation")
            return -5.0

        levels = ["low", "intermediate", "proficient"]

        # Valid JSON that is not an object (e.g. a bare number) cannot hold
        # the required keys; treat it like any other missing-key output.
        if not isinstance(data, dict):
            return -2.0
        if not all(f"{lvl}_health_literacy" in data for lvl in levels):
            return -2.0

        try:
            total_reward = 0.0
            pass_reward = 1.0
            fail_penalty = -1.0
            for lvl in levels:
                gen_text = data.get(f"{lvl}_health_literacy", "")

                # NOTE(review): an empty text costs one penalty, while a text
                # failing both checks costs two -- confirm this is intended.
                if not gen_text:
                    total_reward += fail_penalty
                    continue

                comp_score, cov_score = self.evaluate_level(
                    gen_text, gold_subs, full_subs, lvl
                )

                limits = self.thresholds[lvl]
                total_reward += pass_reward if comp_score >= limits["comp"] else fail_penalty
                total_reward += pass_reward if cov_score >= limits["cov"] else fail_penalty

            return total_reward
        except Exception as e:
            # Best-effort boundary: report the failure instead of hiding it.
            print(f"Reward calculation error: {e}")
            return -5.0
| |
|
| |
|
| | |
# Demo inputs: sub-claims taken from the gold summary. Each of them also
# appears in ``full_text_subclaims`` below.
gold_summary_subclaims = [
    "Hypertension is defined as blood pressure above 140/90 mmHg.",
    "Lifestyle changes like low salt intake can reduce blood pressure.",
    "Diuretics are often the first line of pharmacological treatment.",
]

# Demo inputs: the gold sub-claims plus four additional sub-claims.
full_text_subclaims = [
    "Hypertension is defined as blood pressure above 140/90 mmHg.",
    "Lifestyle changes like low salt intake can reduce blood pressure.",
    "Diuretics are often the first line of pharmacological treatment.",
    "The DASH diet emphasizes fruits, vegetables, and low-fat dairy.",
    "Chronic hypertension increases the risk of stroke and myocardial infarction.",
    "ACE inhibitors are contraindicated during pregnancy.",
    "Secondary hypertension can be caused by renal artery stenosis.",
]

# A sample completion in the shape ``get_reward_score`` reads: a list whose
# first item carries a JSON string under 'content'.
mock_completion = [{
    'content': """
    {
        "low_health_literacy": "High blood pressure is when your blood is too strong for your veins. You should eat less salt to help stay healthy.",
        "intermediate_health_literacy": "Hypertension is blood pressure over 140/90. You can lower it by eating less salt and taking water pills (diuretics) if your doctor says so.",
        "proficient_health_literacy": "Hypertension (BP > 140/90 mmHg) is managed via lifestyle modifications like the DASH diet and salt restriction. Pharmacological interventions include diuretics as first-line therapy, though risks like stroke or heart attack persist if untreated. Secondary causes like renal artery stenosis should be screened, and ACE inhibitors must be avoided in pregnancy."
    }
    """
}]


# Run the demo only when executed as a script. Constructing the verifier
# reads the key file and scoring issues API calls, which must not happen as
# a side effect of importing this module.
if __name__ == "__main__":
    verifier = MedicalClaimVerifier()

    reward = verifier.get_reward_score(
        completion=mock_completion,
        gold_subs=gold_summary_subclaims,
        full_subs=full_text_subclaims,
    )

    print("--- Evaluation Result ---")
    print(f"Total Reward Score: {reward}")
| |
|
| | |
| | |
| | |
| | |