| import os |
| import json |
| import re |
| import concurrent.futures |
| import dspy |
| from openai import OpenAI |
|
|
class MedicalClaimVerifier:
    """Checks whether generated text supports a list of sub-claims.

    Each (context, claim) pair is sent to an OpenAI-compatible chat
    endpoint (e.g. a vLLM server) that answers 'supported' /
    'not_supported'; the per-claim verdicts are aggregated into
    completeness and coverage fractions by :meth:`evaluate_level`.
    """

    def __init__(self):
        # Endpoint configuration comes from the environment, with defaults
        # pointing at the local vLLM deployment.
        self.model_name = os.getenv("VLLM_MODEL", "support_check")
        base_url = os.getenv("VLLM_BASE_URL", "http://172.16.34.21:8086/v1")
        api_key = os.getenv("VLLM_API_KEY", "")
        if not api_key:
            # Fall back to a local key file. Any failure (missing file, bad
            # JSON, missing/empty "openai" entry) degrades to the "EMPTY"
            # placeholder — previously a missing key yielded "" while an
            # exception yielded "EMPTY"; make both paths consistent.
            api_file = "/home/mshahidul/api_new.json"
            try:
                with open(api_file, "r") as f:
                    api_keys = json.load(f)
                api_key = api_keys.get("openai", "") or "EMPTY"
            except Exception:
                api_key = "EMPTY"
        self.client = OpenAI(api_key=api_key, base_url=base_url)

        # Per-literacy-level calibrated thresholds; compute_score() subtracts
        # these from the raw comp/cov fractions.
        self.thresholds = {
            "low": {"comp": 1.0, "cov": 0.3226},
            "intermediate": {"comp": 1.0, "cov": 0.4091},
            "proficient": {"comp": 1.0, "cov": 0.9347},
        }

    def get_prompt(self, context, claim):
        """Build the single-word entailment prompt for one (context, claim)."""
        prompt = f"""
CONTEXT:
{context}

CLAIM TO VERIFY:
{claim}

INSTRUCTION:
Does the CONTEXT above provide enough evidence to support the CLAIM?
- Answer 'supported' if the claim is explicitly stated or logically followable.
- Answer 'not_supported' if the claim is missing, contradicts the text, or requires outside info.

Output only one word: 'supported' or 'not_supported'.
"""
        return prompt

    @staticmethod
    def _parse_verdict(res):
        """Map a raw model reply to 1.0 (supported) or 0.0 (not supported).

        Fixes the old substring test (``"supported" in res and
        "not_supported" not in res``), which misread replies such as
        "not supported" (space) or "unsupported" as positive.
        """
        res = res.strip().lower()
        # Any negated form counts as a negative verdict.
        if re.search(r"\b(?:not[\s_-]*supported|unsupported)\b", res):
            return 0.0
        return 1.0 if "supported" in res else 0.0

    def check_support_api(self, prompt):
        """Ask the chat endpoint to verify one claim; return 1.0/0.0.

        Best-effort: any API or parsing error counts as unsupported (0.0)
        rather than raising, so batch scoring never aborts mid-run.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
            )
            res = response.choices[0].message.content.strip().lower()
            return self._parse_verdict(res)
        except Exception:
            return 0.0

    def _score_claims(self, gen_text, claims):
        """Verify each claim against gen_text in parallel; return 0/1 list."""
        prompts = [self.get_prompt(gen_text, claim) for claim in claims]
        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            return list(executor.map(self.check_support_api, prompts))

    def evaluate_level(self, gen_text, gold_subs, full_subs):
        """Score gen_text against gold (completeness) and full (coverage) claims.

        Returns:
            (comp_score, cov_score) — fraction of each claim list judged
            supported. (0.0, 0.0) when any input is empty.
        """
        if not gen_text or not gold_subs or not full_subs:
            return 0.0, 0.0

        comp_results = self._score_claims(gen_text, gold_subs)
        cov_results = self._score_claims(gen_text, full_subs)

        comp_score = sum(comp_results) / len(gold_subs)
        cov_score = sum(cov_results) / len(full_subs)
        return comp_score, cov_score
|
|
| verifier = MedicalClaimVerifier() |
|
|
# Base URL of the OpenAI-compatible server backing the DSPy classifier LM.
LLM_CPP_API_BASE = os.environ.get("LLM_CPP_API_BASE", "http://172.16.34.21:8034/v1")
# Path to the compiled DSPy health-literacy classifier state (model.json).
MODEL_PATH = os.environ.get(
    "HEALTH_LITERACY_MODEL_PATH",
    "/home/mshahidul/readctrl/code/text_classifier/"
    "dspy_model/vllm-Meta-Llama-3.1-8B-Instruct_teacher-gpt5_v1/model.json",
)
|
|
# Configure DSPy's process-global LM at import time so the classifier module
# below can run. temperature=0.0 keeps classification deterministic.
llama_cpp_lm = dspy.LM(
    model="openai/dspy",
    api_base=LLM_CPP_API_BASE,
    api_key="EMPTY",  # placeholder; presumably the server ignores auth — confirm
    temperature=0.0,
)
dspy.configure(lm=llama_cpp_lm)
|
|
|
|
# NOTE: in DSPy, the class docstring and the field `desc` strings below are
# sent to the LM as task instructions — they are prompt text, not
# documentation, so do not reword them casually.
class HealthLiteracySignature(dspy.Signature):
    """
    Analyze the linguistic complexity, use of medical jargon, and sentence
    structure of 'generated_text' to determine the health literacy level.
    """

    # Input: the rewritten passage whose literacy level is being judged.
    generated_text = dspy.InputField(
        desc="A version of the source text rewritten for a specific audience."
    )
    # Output: expected to be one of the three *_health_literacy labels
    # described in the desc; consumed (lowercased) by _predict_label below.
    literacy_label = dspy.OutputField(
        desc=(
            "Classification: low_health_literacy (simple words, no jargon), "
            "intermediate_health_literacy (moderate technicality), or "
            "proficient_health_literacy (highly technical/original level)."
        )
    )
|
|
|
|
class HealthLiteracyClassifier(dspy.Module):
    """DSPy module wrapping a chain-of-thought health-literacy classifier."""

    def __init__(self):
        super().__init__()
        # NOTE(review): the attribute name is significant — saved DSPy state
        # is keyed by it, so renaming `classifier` would likely break
        # `classifier.load(path)` in _load_compiled_classifier.
        self.classifier = dspy.ChainOfThought(HealthLiteracySignature)

    def forward(self, generated_text):
        # Returns a prediction exposing `literacy_label` (read by
        # _predict_label elsewhere in this file).
        return self.classifier(generated_text=generated_text)
|
|
|
|
| _COMPILED_CLASSIFIER = None |
|
|
|
|
def _load_compiled_classifier(path):
    """Load a compiled DSPy program from *path*.

    Tries the modern ``dspy.load`` entry point first (when the installed
    DSPy version provides it); if that is unavailable or fails, falls back
    to instantiating a fresh HealthLiteracyClassifier and loading the saved
    state into it.

    Raises:
        RuntimeError: when the fallback load also fails.
    """
    loader = getattr(dspy, "load", None)
    if loader is not None:
        try:
            return loader(path)
        except Exception:
            # Older save formats are not dspy.load-compatible; fall through.
            pass
    program = HealthLiteracyClassifier()
    try:
        program.load(path)
    except Exception as exc:
        raise RuntimeError(f"Failed to load compiled model from {path}") from exc
    return program
|
|
|
|
def _get_classifier():
    """Return the cached compiled classifier, loading it on first call.

    Raises:
        FileNotFoundError: when MODEL_PATH does not exist on first load.
    """
    global _COMPILED_CLASSIFIER
    if _COMPILED_CLASSIFIER is not None:
        return _COMPILED_CLASSIFIER
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")
    _COMPILED_CLASSIFIER = _load_compiled_classifier(MODEL_PATH)
    return _COMPILED_CLASSIFIER
|
|
| def _parse_solution_json(solution_str): |
| try: |
| cleaned_str = solution_str.strip() |
| if "```json" in cleaned_str: |
| cleaned_str = cleaned_str.split("```json")[1].split("```")[0].strip() |
| elif "```" in cleaned_str: |
| cleaned_str = cleaned_str.split("```")[1].split("```")[0].strip() |
| return json.loads(cleaned_str) |
| except Exception: |
| return None |
|
|
|
|
| def _get_target_level(extra_info): |
| if not extra_info: |
| return None |
| return extra_info.get("target_level") |
|
|
|
|
def _predict_label(generated_text):
    """Classify *generated_text* and return the predicted label, normalized.

    Returns:
        The lowercased, stripped `literacy_label` string, or "" when the
        prediction is missing or lacks that attribute.
    """
    prediction = _get_classifier()(generated_text=generated_text)
    if prediction and hasattr(prediction, "literacy_label"):
        return str(prediction.literacy_label).strip().lower()
    return ""
|
|
|
|
def _compute_classifier_reward(target_level, gen_text):
    """Return 1.0 when the predicted label contains *target_level*, else 0.0.

    Best-effort: any prediction failure (model/load errors) yields 0.0
    rather than raising.
    """
    try:
        predicted = _predict_label(gen_text)
    except Exception:
        return 0.0
    return float(target_level in predicted)
|
|
|
|
def compute_score(data_source, solution_str, ground_truth, extra_info=None):
    """Reward for a readability-controlled rewrite.

    Combines (1) claim-verification completeness/coverage relative to the
    per-level calibrated thresholds and (2) a classifier check that the
    rewrite matches the requested health-literacy level.

    Args:
        data_source: unused; kept for the reward-function interface.
        solution_str: model output expected to contain a JSON object mapping
            literacy-level names to rewritten text (optionally code-fenced).
        ground_truth: dict with 'summary_subclaims' and 'fulltext_subclaims'
            claim lists.
        extra_info: dict carrying 'target_level'
            (e.g. "low_health_literacy"); may be None.

    Returns:
        float reward. 0.0 for malformed/missing inputs; -1.0 when the
        solution JSON omits the requested level entirely.
    """
    gold_subs = ground_truth.get('summary_subclaims', [])
    full_subs = ground_truth.get('fulltext_subclaims', [])
    if not gold_subs or not full_subs:
        return 0.0

    data = _parse_solution_json(solution_str)
    # _parse_solution_json returns whatever JSON value parsed — a list,
    # string, or number would previously crash on .get() below, so require
    # a non-empty dict.
    if not data or not isinstance(data, dict):
        return 0.0

    target_level = _get_target_level(extra_info)
    if not target_level:
        return 0.0

    # Map the verbose label used in prompts/extra_info to the short key
    # used by verifier.thresholds.
    level_map = {
        "low_health_literacy": "low",
        "intermediate_health_literacy": "intermediate",
        "proficient_health_literacy": "proficient",
    }
    level_key = level_map.get(target_level)
    if not level_key:
        return 0.0

    gen_text = data.get(target_level, "")
    if not gen_text:
        # Explicit penalty: the model answered but skipped the requested level.
        return -1.0

    comp_s, cov_s = verifier.evaluate_level(gen_text, gold_subs, full_subs)
    thresh = verifier.thresholds[level_key]

    # Signed distance from the calibrated thresholds: meeting a threshold
    # contributes 0; falling short subtracts from the reward.
    total_reward = (comp_s - thresh["comp"]) + (cov_s - thresh["cov"])

    # +1.0 when the literacy classifier agrees the text matches target_level.
    classifier_reward = _compute_classifier_reward(target_level, gen_text)
    return total_reward + classifier_reward