# shahidul034's picture
# Add files using upload-large-folder tool
# c7a6fe6 verified
import os
import json
import re
import concurrent.futures
from openai import OpenAI
class MedicalClaimVerifier:
    """Verify whether generated text supports medical subclaims.

    Sends one supported/not_supported judgment per subclaim to an
    OpenAI-compatible (vLLM) chat endpoint and aggregates the results into
    comprehensiveness and coverage fractions per health-literacy level.
    """

    def __init__(self):
        # Prefer local vLLM (OpenAI-compatible) server settings from the
        # environment; fall back to hard-coded defaults.
        self.model_name = os.getenv("VLLM_MODEL", "support_check")
        base_url = os.getenv("VLLM_BASE_URL", "http://172.16.34.21:8086/v1")
        api_key = os.getenv("VLLM_API_KEY", "")
        if not api_key:
            # Fall back to a local credentials file; "EMPTY" is the
            # conventional placeholder key for auth-less vLLM servers.
            api_file = "/home/mshahidul/api_new.json"
            try:
                with open(api_file, "r") as f:
                    api_keys = json.load(f)
                api_key = api_keys.get("openai", "")
            except Exception:
                api_key = "EMPTY"
        self.client = OpenAI(api_key=api_key, base_url=base_url)
        # Per-level passing thresholds for comprehensiveness ("comp") and
        # coverage ("cov"); consumed by the reward-shaping caller.
        self.thresholds = {
            "low": {"comp": 1.0, "cov": 0.3226},
            "intermediate": {"comp": 1.0, "cov": 0.4091},
            "proficient": {"comp": 1.0, "cov": 0.9347},
        }

    def get_prompt(self, context, claim):
        """Build the one-word supported/not_supported verification prompt."""
        prompt = f"""
CONTEXT:
{context}
CLAIM TO VERIFY:
{claim}
INSTRUCTION:
Does the CONTEXT above provide enough evidence to support the CLAIM?
- Answer 'supported' if the claim is explicitly stated or logically followable.
- Answer 'not_supported' if the claim is missing, contradicts the text, or requires outside info.
Output only one word: 'supported' or 'not_supported'.
"""
        return prompt

    @staticmethod
    def _parse_verdict(res):
        """Map a raw model reply to 1.0 (supported) or 0.0 (not supported).

        Normalizes whitespace variants ("not supported") to the underscore
        form before checking, and tests the negative verdict first —
        "supported" is a substring of "not_supported", so a naive substring
        check would misread a spaced-out negative reply as positive.
        """
        normalized = res.strip().lower().replace(" ", "_")
        if "not_supported" in normalized:
            return 0.0
        return 1.0 if "supported" in normalized else 0.0

    def check_support_api(self, prompt):
        """Ask the model one verification question; best-effort 0.0 on any error."""
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
            )
            return self._parse_verdict(response.choices[0].message.content)
        except Exception:
            # Deliberate best-effort: a failed API call counts as unsupported
            # rather than aborting the whole evaluation.
            return 0.0

    def evaluate_level(self, gen_text, gold_subs, full_subs):
        """Return (comprehensiveness, coverage) fractions for one level.

        comp = fraction of summary (gold) subclaims supported by gen_text;
        cov  = fraction of full-text subclaims supported by gen_text.
        """
        if not gen_text or not gold_subs or not full_subs:
            return 0.0, 0.0
        # Check gold and full claims as separate per-claim requests (fanned
        # out over threads) to avoid context-length issues.
        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            comp_results = list(
                executor.map(
                    self.check_support_api,
                    [self.get_prompt(gen_text, s) for s in gold_subs],
                )
            )
        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            cov_results = list(
                executor.map(
                    self.check_support_api,
                    [self.get_prompt(gen_text, s) for s in full_subs],
                )
            )
        comp_score = sum(comp_results) / len(gold_subs)
        cov_score = sum(cov_results) / len(full_subs)
        return comp_score, cov_score
# Module-level singleton, constructed at import time — note this reads env
# vars / the local API-key file and builds the OpenAI client as a side effect.
verifier = MedicalClaimVerifier()
def compute_score(data_source, solution_str, ground_truth, extra_info=None):
    """Reward for a three-level health-literacy rewrite of a medical text.

    Args:
        data_source: identifier of the originating dataset (unused here).
        solution_str: model output; JSON object (optionally fenced in
            markdown ```/```json blocks) with keys
            "{low,intermediate,proficient}_health_literacy".
        ground_truth: dict with 'summary_subclaims' and 'fulltext_subclaims'
            lists used for comprehensiveness/coverage checks.
        extra_info: unused, kept for interface compatibility.

    Returns:
        0.0 when ground truth is incomplete; -5.0 on unparseable output;
        otherwise a shaped reward (threshold-relative scores, -1.0 per
        missing level, -2.0 if coverage is not monotonically increasing
        across literacy levels).
    """
    gold_subs = ground_truth.get('summary_subclaims', [])
    full_subs = ground_truth.get('fulltext_subclaims', [])
    if not gold_subs or not full_subs:
        return 0.0
    # 1. Parsing with fallback: strip optional markdown code fences.
    try:
        cleaned_str = solution_str.strip()
        if "```json" in cleaned_str:
            cleaned_str = cleaned_str.split("```json")[1].split("```")[0].strip()
        elif "```" in cleaned_str:
            cleaned_str = cleaned_str.split("```")[1].split("```")[0].strip()
        data = json.loads(cleaned_str)
        if not isinstance(data, dict):
            # Valid JSON but not an object (e.g. a bare list/number): the
            # .get() lookups below would raise AttributeError, so treat it
            # as a parse failure instead of crashing.
            raise ValueError("solution JSON is not an object")
    except Exception:
        return -5.0
    levels = ["low", "intermediate", "proficient"]
    scores = {}
    # 2. Score Calculation — one (comp, cov) pair per literacy level.
    for lvl in levels:
        gen_text = data.get(f"{lvl}_health_literacy", "")
        if not gen_text:
            scores[lvl] = {"comp": 0.0, "cov": 0.0, "missing": True}
        else:
            comp, cov = verifier.evaluate_level(gen_text, gold_subs, full_subs)
            scores[lvl] = {"comp": comp, "cov": cov, "missing": False}
    # 3. Reward Shaping Logic
    total_reward = 0.0
    low_cov = scores["low"]["cov"]
    int_cov = scores["intermediate"]["cov"]
    pro_cov = scores["proficient"]["cov"]
    # Soft Hierarchy Check: coverage should not decrease as literacy rises.
    # A penalty is subtracted instead of an early -2.0 exit so the model
    # still receives per-level gradient signal.
    hierarchy_penalty = 0.0
    if not (low_cov <= int_cov <= pro_cov):
        hierarchy_penalty = -2.0
    for lvl in levels:
        if scores[lvl]["missing"]:
            total_reward -= 1.0  # Penalty per missing field
            continue
        comp_s = scores[lvl]["comp"]
        cov_s = scores[lvl]["cov"]
        thresh = verifier.thresholds[lvl]
        # Continuous Reward: (Actual - Threshold).
        # This tells the model "You're 10% away" vs "You failed".
        total_reward += (comp_s - thresh["comp"])
        total_reward += (cov_s - thresh["cov"])
    return total_reward + hierarchy_penalty
def run_mock_example():
    """Smoke-test compute_score against a hand-built Metformin example.

    Requires a reachable vLLM server; issues 36 verification API calls
    (3 levels x (3 gold + 6 full subclaims)).
    """
    # Ground-truth subclaims, as if extracted from a source text about
    # "Metformin for Type 2 Diabetes".
    truth = {
        "summary_subclaims": [
            "Metformin is a first-line medication for Type 2 Diabetes.",
            "Common side effects include gastrointestinal upset.",
            "It helps lower blood glucose levels.",
        ],
        "fulltext_subclaims": [
            "Metformin is a first-line medication for Type 2 Diabetes.",
            "It works by reducing glucose production in the liver.",
            "Common side effects include nausea and diarrhea.",
            "Patients should take it with meals to reduce stomach issues.",
            "It does not typically cause weight gain.",
            "Long-term use may lead to Vitamin B12 deficiency.",
        ],
    }
    # Mock LLM answer: 'low' is deliberately basic, 'proficient' detailed.
    answer = {
        "low_health_literacy": "Metformin is used for diabetes and helps lower blood sugar.",
        "intermediate_health_literacy": "Metformin is a first-line treatment for Type 2 Diabetes. It lowers glucose and can cause stomach upset.",
        "proficient_health_literacy": "Metformin is the primary treatment for Type 2 Diabetes. It reduces hepatic glucose production. Side effects include gastrointestinal issues like nausea, but taking it with food helps. It is weight-neutral and may cause B12 deficiency over time.",
    }
    # Wrap in a markdown fence, exactly as a chat model would emit it.
    fenced = f"```json\n{json.dumps(answer)}\n```"
    print("--- Starting Complex Evaluation ---")
    # Make sure your vLLM server is running before this point!
    reward = compute_score("mock_source", fenced, truth)
    print(f"\nFinal Calculated Reward: {reward:.4f}")
# Manual smoke test entry point; needs a live vLLM endpoint to complete.
if __name__ == "__main__":
    run_mock_example()