# NOTE: the CUDA environment variables must be set *before* torch/unsloth are
# imported, otherwise the device selection is silently ignored.
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # order GPUs by PCI bus id (stable across reboots)
os.environ["CUDA_VISIBLE_DEVICES"] = "2"        # pin this process to physical GPU 2

import argparse
import json

import torch
import tqdm
from unsloth import FastLanguageModel
|
|
| |
| |
| |
# --- Inference configuration -------------------------------------------------
# Path to the fine-tuned subclaim support-checking model (merged bf16 weights).
MODEL_PATH = "/home/mshahidul/readctrl_model/full_model/nemotron-3-nano-30b-a3b_subclaims-support-check-8b_ctx_v2-bf16"
max_seq_length = 2048  # maximum context length passed to unsloth
dtype = None           # None lets unsloth auto-detect the dtype
load_in_4bit = True    # 4-bit quantized loading to reduce VRAM usage
|
|
| |
# Load the model + tokenizer once at module import time; both are used as
# globals by check_support() below.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = MODEL_PATH,
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    trust_remote_code = True,
)

# Switch the unsloth model into optimized inference mode.
FastLanguageModel.for_inference(model)
|
|
| |
| |
| |
def inference_prompt(text: str, subclaim: str) -> str:
    """Build the strict-grounding audit prompt for one (text, subclaim) pair.

    The prompt instructs the model to judge the subclaim using ONLY the given
    medical text and to answer with exactly one word:
    'supported' or 'not_supported'.
    """
    return f"""You are a clinical evidence auditor. Your evaluation must be based STRICTLY and ONLY on the provided medical text.

### MANDATORY GROUNDING RULES:
1. NO OUTSIDE KNOWLEDGE: Do not use your internal medical knowledge. Even if a subclaim is "common sense" in medicine, if it is not explicitly in the TEXT, it is 'not_supported'.
2. NO LOGICAL LEAPS: Do not bridge gaps in logic. (e.g., If the text mentions "high blood sugar" but not the word "diabetes", you cannot support a claim of "diabetes").
3. EXACT NUMERICAL MATCHING: Any doses (e.g., 500mg), frequencies (e.g., twice daily), or durations (e.g., 10 days) mentioned in the subclaim must match the text perfectly. If they are missing or different in the text, label as 'not_supported'.
4. DEFAULT TO NOT SUPPORTED: If the text is vague, ambiguous, or only suggests a possibility, you MUST choose 'not_supported'.
5. CLOSED-WORLD REALITY: Treat the TEXT as the only information that exists in the world.

### Medical Text:
{text}

### Subclaim:
{subclaim}

Output exactly one word ('supported' or 'not_supported') based on the strict rules above:"""
|
|
| |
| |
| |
def check_support(text: str, subclaim: str, error_log=None) -> str:
    """Classify whether `subclaim` is supported by `text` using the local model.

    Returns one of 'supported', 'not_supported', or 'refuted'.  Any failure
    (empty input, tokenization/generation error, unparseable output) falls
    back to 'not_supported'.  When `error_log` (a list) is provided, inference
    errors are appended to it as dicts.
    """
    # Vacuous inputs cannot be supported; fail closed.
    if not text or not subclaim:
        return "not_supported"

    prompt_content = inference_prompt(text, subclaim)

    messages = [{"role": "user", "content": prompt_content}]
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize = True,
        add_generation_prompt = True,
        return_tensors = "pt",
    ).to("cuda")  # model lives on the GPU pinned via CUDA_VISIBLE_DEVICES above

    try:
        outputs = model.generate(
            input_ids = inputs,
            max_new_tokens = 512,  # headroom for reasoning tokens before the one-word verdict
            temperature = 0.1,
            use_cache = True,
        )

        # Decode only the newly generated tokens (drop the echoed prompt).
        res = tokenizer.batch_decode(outputs[:, inputs.shape[1]:], skip_special_tokens=True)[0]
        res = res.strip().lower()

        # If the model emitted a reasoning block, keep only the text after it.
        if "</think>" in res:
            res = res.split("</think>")[1].strip().lower()

        # Order matters: 'supported' is a substring of 'not_supported', so the
        # negative label must be tested first.
        if "not_supported" in res:
            return "not_supported"
        elif "supported" in res:
            return "supported"
        elif "refuted" in res:
            # Not requested by the prompt, but tolerated if the model emits it.
            return "refuted"
        else:
            return "not_supported"  # unparseable output: fail closed

    except Exception as e:
        # Best-effort boundary: record the failure and return the safe label.
        if error_log is not None:
            error_details = {"subclaim": subclaim, "error_msg": str(e), "type": "LOCAL_INFERENCE_ERROR"}
            error_log.append(error_details)
        return "not_supported"
|
|
| |
| |
| |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Evaluate subclaim support labels with the local model."
    )
    parser.add_argument("--input_file", type=str,
                        default="/home/mshahidul/readctrl/data/model_validity_check/subclaims_support_validity_check_gt_gpt5(1-5).json")
    parser.add_argument("--save_folder", type=str,
                        default="/home/mshahidul/readctrl/data/concise_complete_attr_testing")
    parser.add_argument("--start_index", type=int, default=0)
    parser.add_argument("--end_index", type=int, default=-1,
                        help="-1 means 'process to the end of the input file'")

    args = parser.parse_args()

    INPUT_FILE = args.input_file
    SAVE_FOLDER = args.save_folder
    os.makedirs(SAVE_FOLDER, exist_ok=True)

    with open(INPUT_FILE, "r", encoding="utf-8") as f:
        all_data = json.load(f)

    # Clamp the requested slice to the available data.
    total_len = len(all_data)
    start = args.start_index
    end = args.end_index if args.end_index != -1 else total_len
    data_slice = all_data[start:min(end, total_len)]

    OUTPUT_FILE = os.path.join(SAVE_FOLDER, f"evaluated_metrics_{start}_{end}_nemotran-30B.json")

    # Resume support: reload any previously saved results so a rerun skips
    # texts that were already evaluated instead of redoing them.
    processed_results = []
    if os.path.exists(OUTPUT_FILE):
        try:
            with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
                processed_results = json.load(f)
        except (json.JSONDecodeError, OSError):
            # Corrupt or unreadable checkpoint: start fresh rather than crash.
            processed_results = []

    processed_ids = {item['medical_text'] for item in processed_results}
    global_error_log = []

    pbar = tqdm.tqdm(data_slice)

    for item in pbar:
        text = item.get('full_text', '')
        if text in processed_ids:
            continue
        # Guard against duplicate texts *within* this run too, not just
        # against texts restored from a previous checkpoint.
        processed_ids.add(text)

        subclaims = item.get('dat', {}).get('dat', [])

        for subclaim_obj in subclaims:
            subclaim_text = subclaim_obj.get('subclaim', '')
            label_gt = subclaim_obj.get('status', 'not_supported').strip().lower()

            label_gen = check_support(text, subclaim_text, error_log=global_error_log)

            correctness = (label_gen == label_gt)

            result_entry = {
                "medical_text": text,
                "subclaim": subclaim_text,
                "label_gt": label_gt,
                "label_gen": label_gen,
                "correctness": correctness,
            }
            processed_results.append(result_entry)

            # Checkpoint after every subclaim so progress survives a crash.
            with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
                json.dump(processed_results, f, indent=2, ensure_ascii=False)

    # Final write (covers the case where the slice was empty or fully skipped).
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        json.dump(processed_results, f, indent=2, ensure_ascii=False)