# readctrl/code/RL_model/unsloth_rl/claim_verifier.py
# (Uploaded via Hugging Face upload-large-folder tool, commit c7a6fe6.)
import json
import re
import concurrent.futures
from openai import OpenAI
class MedicalClaimVerifier:
    """Reward model for literacy-tiered medical summaries in an RL loop.

    An OpenAI chat model acts as an entailment judge: each subclaim is checked
    against the generated text. Completeness (vs. gold-summary subclaims) and
    source coverage (vs. full-text subclaims) rates are compared to
    per-literacy-level thresholds to produce a scalar reward.
    """

    def __init__(self):
        # NOTE(review): hard-coded credential path — consider an env var.
        api_file = "/home/mshahidul/api_new.json"
        with open(api_file, "r") as f:
            api_keys = json.load(f)
        self.api_key = api_keys["openai"]
        self.model_name = "gpt-5-mini"
        self.client = OpenAI(api_key=self.api_key)
        # Literacy ranges (IQR after outlier removal) from paper summary.
        # comp = completeness vs gold summary; cov = source_coverage vs full text.
        self.threshold_ranges = {
            "low": {"comp": (0.9600, 1.0000), "cov": (0.1765, 0.3226)},
            "intermediate": {"comp": (0.9393, 1.0000), "cov": (0.1818, 0.4091)},
            "proficient": {"comp": (0.9231, 1.0000), "cov": (0.7725, 0.9347)},
        }
        # Minimum required information (upper bound of IQR).
        self.thresholds = {
            "low": {"comp": 1.0, "cov": 0.3226},
            "intermediate": {"comp": 1.0, "cov": 0.4091},
            "proficient": {"comp": 1.0, "cov": 0.9347},
        }

    def get_prompt(self, context, claim):
        """Build the one-word entailment prompt for (context, claim)."""
        prompt = f"""
CONTEXT:
{context}
CLAIM TO VERIFY:
{claim}
INSTRUCTION:
Does the CONTEXT above provide enough evidence to support the CLAIM?
- Answer 'supported' if the claim is explicitly stated or logically followable.
- Answer 'not_supported' if the claim is missing, contradicts the text, or requires outside info.
Output only one word: 'supported' or 'not_supported'.
"""
        return prompt

    def check_support_api(self, prompt):
        """Ask the judge model whether a claim is supported.

        Returns 1.0 for a 'supported' verdict, 0.0 for 'not_supported' or on
        any API failure (fail-closed so a flaky API cannot inflate rewards).
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
            )
            res = response.choices[0].message.content.strip().lower()
            # BUGFIX: models often reply "not supported" (space, no underscore),
            # which the raw substring test would count as supported. Normalize
            # whitespace to underscores before checking.
            verdict = res.replace(" ", "_")
            return 1.0 if "supported" in verdict and "not_supported" not in verdict else 0.0
        except Exception as e:
            print(f"API call error: {e}")
            return 0.0

    def evaluate_level(self, gen_text, gold_subs, full_subs, level_key):
        """Calculate (completeness, coverage) scores for one literacy level.

        completeness = fraction of gold-summary subclaims supported by gen_text;
        coverage     = fraction of full-text subclaims supported by gen_text.
        Returns (0.0, 0.0) for empty text or on any API failure.
        """
        if not gen_text:
            return 0.0, 0.0
        # Run API calls in parallel to save time during RL.
        try:
            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                # Completeness check (vs gold-summary subclaims).
                comp_prompts = [self.get_prompt(gen_text, s) for s in gold_subs]
                comp_results = list(executor.map(self.check_support_api, comp_prompts))
                comp_score = sum(comp_results) / len(comp_results) if comp_results else 0.0
                # Coverage check (vs full-text subclaims).
                cov_prompts = [self.get_prompt(gen_text, s) for s in full_subs]
                cov_results = list(executor.map(self.check_support_api, cov_prompts))
                cov_score = sum(cov_results) / len(cov_results) if cov_results else 0.0
        except Exception as e:
            print(f"Parallel API call error: {e}")
            return 0.0, 0.0
        return comp_score, cov_score

    def get_reward_score(self, completion, gold_subs, full_subs):
        """Compute the scalar RL reward for one model completion.

        completion: list with one {'content': str} message whose content is a
        (possibly markdown-fenced) JSON object with keys
        '{low,intermediate,proficient}_health_literacy'.

        Returns:
            -5.0 if the JSON cannot be extracted/parsed (or scoring crashes),
            -2.0 if parsed but a required literacy key is missing,
            otherwise sum over levels of +/-1 per threshold met/missed
            (range -6.0 .. +6.0, -1 per empty level text).
        """
        # 1. Robust JSON extraction: strip markdown fences and whitespace.
        try:
            text = completion[0]['content'].strip().replace("```json", "").replace("```", "").strip()
            data = json.loads(text)
        # BUGFIX: also catch KeyError/TypeError so a malformed completion
        # structure (missing 'content', non-dict message) is penalized
        # instead of crashing the training step.
        except (json.JSONDecodeError, IndexError, KeyError, TypeError, ValueError):
            print("JSON Parsing Error in Reward Calculation")
            return -5.0
        # 2. Schema validation: all three literacy keys must be present.
        levels = ["low", "intermediate", "proficient"]
        if not all(f"{lvl}_health_literacy" in data for lvl in levels):
            return -2.0  # Slightly smaller penalty for partial formatting success
        # 3. Scoring logic: +1 per threshold passed, -1 per threshold failed.
        try:
            total_reward = 0.0
            pass_reward = 1.0
            fail_penalty = -1.0
            for lvl in levels:
                gen_text = data.get(f"{lvl}_health_literacy", "")
                # Empty section: single penalty, skip the API scoring.
                if not gen_text:
                    total_reward += fail_penalty
                    continue
                comp_score, cov_score = self.evaluate_level(gen_text, gold_subs, full_subs, lvl)
                # Apply thresholds (minimum required completeness/coverage).
                total_reward += pass_reward if comp_score >= self.thresholds[lvl]["comp"] else fail_penalty
                total_reward += pass_reward if cov_score >= self.thresholds[lvl]["cov"] else fail_penalty
            # BUGFIX: return AFTER the loop so all three levels are scored;
            # the original returned inside the loop after the first level.
            return total_reward
        except Exception as e:
            # Log instead of silently swallowing, but still fail-closed.
            print(f"Reward scoring error: {e}")
            return -5.0
# --- Smoke-test fixtures ----------------------------------------------------
# 1. Ground truth subclaims (extracted from a medical paper on hypertension).
gold_summary_subclaims = [
    "Hypertension is defined as blood pressure above 140/90 mmHg.",
    "Lifestyle changes like low salt intake can reduce blood pressure.",
    "Diuretics are often the first line of pharmacological treatment.",
]
full_text_subclaims = [
    "Hypertension is defined as blood pressure above 140/90 mmHg.",
    "Lifestyle changes like low salt intake can reduce blood pressure.",
    "Diuretics are often the first line of pharmacological treatment.",
    "The DASH diet emphasizes fruits, vegetables, and low-fat dairy.",
    "Chronic hypertension increases the risk of stroke and myocardial infarction.",
    "ACE inhibitors are contraindicated during pregnancy.",
    "Secondary hypertension can be caused by renal artery stenosis.",
]
# 2. Mock model completion (the output being evaluated).
# This mimics the format the RL environment would pass to the reward function.
mock_completion = [{
    'content': """
{
  "low_health_literacy": "High blood pressure is when your blood is too strong for your veins. You should eat less salt to help stay healthy.",
  "intermediate_health_literacy": "Hypertension is blood pressure over 140/90. You can lower it by eating less salt and taking water pills (diuretics) if your doctor says so.",
  "proficient_health_literacy": "Hypertension (BP > 140/90 mmHg) is managed via lifestyle modifications like the DASH diet and salt restriction. Pharmacological interventions include diuretics as first-line therapy, though risks like stroke or heart attack persist if untreated. Secondary causes like renal artery stenosis should be screened, and ACE inhibitors must be avoided in pregnancy."
}
"""
}]

# BUGFIX: guard the live-API smoke test so merely importing this module does
# not read the credentials file or spend API calls during RL training.
if __name__ == "__main__":
    verifier = MedicalClaimVerifier()
    reward = verifier.get_reward_score(
        completion=mock_completion,
        gold_subs=gold_summary_subclaims,
        full_subs=full_text_subclaims,
    )
    print(f"--- Evaluation Result ---")
    print(f"Total Reward Score: {reward}")
    # Logic Explanation:
    # - Low: Likely fails 'comp' (missing 140/90 info), but might pass 'cov' (low threshold).
    # - Intermediate: Likely passes 'comp' and 'cov'.
    # - Proficient: Needs to cover almost all 7 subclaims to pass the 0.77 coverage threshold.