# Source: Hugging Face upload by shahidul034 (commit c7a6fe6, via upload-large-folder tool).
import os
import json
import re
import concurrent.futures
import dspy
from openai import OpenAI
class MedicalClaimVerifier:
    """Verifies whether generated text supports medical sub-claims.

    Uses an OpenAI-compatible chat endpoint (a local vLLM server by
    default) to answer, per sub-claim, whether the generated text
    provides enough evidence ('supported' / 'not_supported').
    """

    def __init__(self):
        # Prefer local vLLM (OpenAI-compatible) server settings.
        self.model_name = os.getenv("VLLM_MODEL", "support_check")
        base_url = os.getenv("VLLM_BASE_URL", "http://172.16.34.21:8086/v1")
        api_key = os.getenv("VLLM_API_KEY", "")
        if not api_key:
            # Fall back to a local credentials file; "EMPTY" is the usual
            # placeholder accepted by vLLM servers that skip auth.
            api_file = "/home/mshahidul/api_new.json"
            try:
                with open(api_file, "r") as f:
                    api_keys = json.load(f)
                api_key = api_keys.get("openai", "")
            except Exception:
                api_key = "EMPTY"
        self.client = OpenAI(api_key=api_key, base_url=base_url)
        # Per-level reward thresholds: completeness ("comp") and coverage
        # ("cov") scores are rewarded relative to these baselines.
        self.thresholds = {
            "low": {"comp": 1.0, "cov": 0.3226},
            "intermediate": {"comp": 1.0, "cov": 0.4091},
            "proficient": {"comp": 1.0, "cov": 0.9347},
        }

    def get_prompt(self, context, claim):
        """Build the one-word supported/not_supported verification prompt."""
        prompt = f"""
CONTEXT:
{context}
CLAIM TO VERIFY:
{claim}
INSTRUCTION:
Does the CONTEXT above provide enough evidence to support the CLAIM?
- Answer 'supported' if the claim is explicitly stated or logically followable.
- Answer 'not_supported' if the claim is missing, contradicts the text, or requires outside info.
Output only one word: 'supported' or 'not_supported'.
"""
        return prompt

    @staticmethod
    def _score_verdict(res):
        """Map a raw model reply to 1.0 (supported) or 0.0 (not supported).

        Normalizes whitespace so replies such as "not supported" (space
        instead of underscore) are scored as negative. The previous
        substring test ("supported" in res and "not_supported" not in res)
        scored any such spaced variant as a positive verdict.
        """
        normalized = res.strip().lower().replace(" ", "_")
        if "not_supported" in normalized or "unsupported" in normalized:
            return 0.0
        return 1.0 if "supported" in normalized else 0.0

    def check_support_api(self, prompt):
        """Query the verifier model; return 1.0 iff it answers 'supported'.

        Any API failure counts as "not supported" (0.0) so a flaky
        endpoint degrades the reward instead of crashing the caller.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
            )
            return self._score_verdict(response.choices[0].message.content)
        except Exception:
            return 0.0

    def evaluate_level(self, gen_text, gold_subs, full_subs):
        """Return (completeness, coverage) support rates for gen_text.

        completeness: fraction of gold (summary) sub-claims supported.
        coverage: fraction of full-text sub-claims supported.
        Returns (0.0, 0.0) when any input is empty.
        """
        if not gen_text or not gold_subs or not full_subs:
            return 0.0, 0.0
        # Verify each sub-claim in its own request to avoid context-length
        # issues; one thread pool serves both claim lists.
        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            comp_results = list(
                executor.map(
                    self.check_support_api,
                    [self.get_prompt(gen_text, s) for s in gold_subs],
                )
            )
            cov_results = list(
                executor.map(
                    self.check_support_api,
                    [self.get_prompt(gen_text, s) for s in full_subs],
                )
            )
        comp_score = sum(comp_results) / len(gold_subs)
        cov_score = sum(cov_results) / len(full_subs)
        return comp_score, cov_score
# Shared claim-verification client (module-level so compute_score can reuse
# one OpenAI client across calls).
verifier = MedicalClaimVerifier()
# OpenAI-compatible endpoint serving the DSPy health-literacy classifier.
LLM_CPP_API_BASE = os.environ.get("LLM_CPP_API_BASE", "http://172.16.34.21:8034/v1")
# Path to the compiled DSPy classifier state; overridable via env var.
MODEL_PATH = os.environ.get(
    "HEALTH_LITERACY_MODEL_PATH",
    "/home/mshahidul/readctrl/code/text_classifier/"
    "dspy_model/vllm-Meta-Llama-3.1-8B-Instruct_teacher-gpt5_v1/model.json",
)
# DSPy LM wrapper over the endpoint; temperature 0.0 for deterministic labels.
llama_cpp_lm = dspy.LM(
    model="openai/dspy",
    api_base=LLM_CPP_API_BASE,
    api_key="EMPTY",
    temperature=0.0,
)
# Import-time side effect: installs the LM as DSPy's global default.
dspy.configure(lm=llama_cpp_lm)
class HealthLiteracySignature(dspy.Signature):
    """
    Analyze the linguistic complexity, use of medical jargon, and sentence
    structure of 'generated_text' to determine the health literacy level.
    """
    # NOTE(review): in DSPy, a Signature's docstring and the field `desc`
    # strings are injected into the LM prompt -- treat them as runtime data,
    # not documentation; do not edit casually.
    generated_text = dspy.InputField(
        desc="A version of the source text rewritten for a specific audience."
    )
    literacy_label = dspy.OutputField(
        desc=(
            "Classification: low_health_literacy (simple words, no jargon), "
            "intermediate_health_literacy (moderate technicality), or "
            "proficient_health_literacy (highly technical/original level)."
        )
    )
class HealthLiteracyClassifier(dspy.Module):
    """Chain-of-thought DSPy module that labels text by health-literacy level."""

    def __init__(self):
        super().__init__()
        # NOTE(review): keep the attribute named `classifier` -- the compiled
        # model state loaded in _load_compiled_classifier likely keys on it.
        self.classifier = dspy.ChainOfThought(HealthLiteracySignature)

    def forward(self, generated_text):
        """Run the classifier over generated_text and return its prediction."""
        prediction = self.classifier(generated_text=generated_text)
        return prediction
# Cached classifier instance; populated lazily by _get_classifier().
_COMPILED_CLASSIFIER = None
def _load_compiled_classifier(path):
    """Load a compiled DSPy program from `path`.

    Tries the newer ``dspy.load`` API first (when present); on any failure
    falls back to instantiating the module and loading its saved state.

    Raises:
        RuntimeError: when the fallback load also fails.
    """
    loader = getattr(dspy, "load", None)
    if loader is not None:
        try:
            return loader(path)
        except Exception:
            pass  # fall through to the state-dict style load below
    classifier = HealthLiteracyClassifier()
    try:
        classifier.load(path)
    except Exception as exc:
        raise RuntimeError(f"Failed to load compiled model from {path}") from exc
    return classifier
def _get_classifier():
    """Return the lazily-loaded, module-cached compiled classifier.

    Raises:
        FileNotFoundError: when MODEL_PATH does not exist on first load.
    """
    global _COMPILED_CLASSIFIER
    if _COMPILED_CLASSIFIER is not None:
        return _COMPILED_CLASSIFIER
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")
    _COMPILED_CLASSIFIER = _load_compiled_classifier(MODEL_PATH)
    return _COMPILED_CLASSIFIER
def _parse_solution_json(solution_str):
try:
cleaned_str = solution_str.strip()
if "```json" in cleaned_str:
cleaned_str = cleaned_str.split("```json")[1].split("```")[0].strip()
elif "```" in cleaned_str:
cleaned_str = cleaned_str.split("```")[1].split("```")[0].strip()
return json.loads(cleaned_str)
except Exception:
return None
def _get_target_level(extra_info):
if not extra_info:
return None
return extra_info.get("target_level")
def _predict_label(generated_text):
    """Classify `generated_text`; return its literacy label lower-cased.

    Returns "" when the classifier yields no prediction or one without a
    `literacy_label` attribute.
    """
    prediction = _get_classifier()(generated_text=generated_text)
    if prediction and hasattr(prediction, "literacy_label"):
        return str(prediction.literacy_label).strip().lower()
    return ""
def _compute_classifier_reward(target_level, gen_text):
    """Return 1.0 when the classifier's label contains `target_level`.

    Any classification error is treated as a miss (0.0) rather than raised.
    """
    try:
        predicted = _predict_label(gen_text)
    except Exception:
        return 0.0
    if target_level in predicted:
        return 1.0
    return 0.0
def compute_score(data_source, solution_str, ground_truth, extra_info=None):
    """Compute the reward for a generated health-literacy rewrite.

    Args:
        data_source: unused here; kept for the reward-function interface.
        solution_str: model output expected to contain a JSON object mapping
            literacy-level names to rewritten texts.
        ground_truth: dict with 'summary_subclaims' and 'fulltext_subclaims'.
        extra_info: optional dict carrying 'target_level'.

    Returns:
        0.0 on any missing/invalid input; -1.0 when the JSON parses but the
        target level's text is absent; otherwise the threshold-relative
        support reward plus a 0/1 classifier-agreement bonus.
    """
    gold_subs = ground_truth.get('summary_subclaims', [])
    full_subs = ground_truth.get('fulltext_subclaims', [])
    if not gold_subs or not full_subs:
        return 0.0
    data = _parse_solution_json(solution_str)
    # Guard against valid JSON that is not an object (e.g. a bare list or
    # string), which would otherwise crash the .get() calls below.
    if not isinstance(data, dict) or not data:
        return 0.0
    target_level = _get_target_level(extra_info)
    if not target_level:
        return 0.0
    # Map the classifier-style label to the verifier's threshold key.
    level_map = {
        "low_health_literacy": "low",
        "intermediate_health_literacy": "intermediate",
        "proficient_health_literacy": "proficient",
    }
    level_key = level_map.get(target_level)
    if not level_key:
        return 0.0
    gen_text = data.get(target_level, "")
    if not gen_text:
        # Penalize producing JSON that omits the requested level entirely.
        return -1.0
    comp_s, cov_s = verifier.evaluate_level(gen_text, gold_subs, full_subs)
    thresh = verifier.thresholds[level_key]
    # Reward is measured relative to the per-level thresholds (can be
    # negative), plus 1.0 when the classifier agrees with the target level.
    total_reward = (comp_s - thresh["comp"]) + (cov_s - thresh["cov"])
    classifier_reward = _compute_classifier_reward(target_level, gen_text)
    return total_reward + classifier_reward