import os

# CUDA env vars must be set before torch/unsloth initialize the CUDA context.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import argparse
import json

import torch
import tqdm
from unsloth import FastLanguageModel

# -----------------------------
# UNSLOTH MODEL CONFIGURATION
# -----------------------------
MODEL_PATH = "/home/mshahidul/readctrl_model/full_model/nemotron-3-nano-30b-a3b_subclaims-support-check-8b_ctx_v2-bf16"
max_seq_length = 2048  # Adjusted for medical text + reasoning context
dtype = None           # Auto-detection for A100 (will likely use bfloat16)
load_in_4bit = True    # To fit 32B model comfortably on A100

# Load model and tokenizer natively.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_PATH,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
    trust_remote_code=True,
)
# Enable 2x faster native inference.
FastLanguageModel.for_inference(model)


# -----------------------------
# VERIFICATION PROMPT
# -----------------------------
def inference_prompt(text, subclaim):
    """Build the strict clinical-evidence-auditor prompt for one (text, subclaim) pair.

    Args:
        text: The medical source text that is the only allowed evidence.
        subclaim: The single subclaim to verify against `text`.

    Returns:
        The full prompt string instructing the model to answer with exactly
        'supported' or 'not_supported'.
    """
    return f"""You are a clinical evidence auditor. Your evaluation must be based STRICTLY and ONLY on the provided medical text.

### MANDATORY GROUNDING RULES:
1. NO OUTSIDE KNOWLEDGE: Do not use your internal medical knowledge. Even if a subclaim is "common sense" in medicine, if it is not explicitly in the TEXT, it is 'not_supported'.
2. NO LOGICAL LEAPS: Do not bridge gaps in logic. (e.g., If the text mentions "high blood sugar" but not the word "diabetes", you cannot support a claim of "diabetes").
3. EXACT NUMERICAL MATCHING: Any doses (e.g., 500mg), frequencies (e.g., twice daily), or durations (e.g., 10 days) mentioned in the subclaim must match the text perfectly. If they are missing or different in the text, label as 'not_supported'.
4. DEFAULT TO NOT SUPPORTED: If the text is vague, ambiguous, or only suggests a possibility, you MUST choose 'not_supported'.
5. CLOSED-WORLD REALITY: Treat the TEXT as the only information that exists in the world.

### Medical Text:
{text}

### Subclaim:
{subclaim}

Output exactly one word ('supported' or 'not_supported') based on the strict rules above:"""


# -----------------------------
# VERIFICATION LOGIC (UNSLOTH VERSION)
# -----------------------------
def check_support(text: str, subclaim: str, error_log=None) -> str:
    """Run local inference to classify whether `subclaim` is supported by `text`.

    Args:
        text: Medical source text (empty/None short-circuits to 'not_supported').
        subclaim: Subclaim to verify (empty/None short-circuits likewise).
        error_log: Optional list; inference errors are appended as dicts.

    Returns:
        One of 'supported', 'not_supported', or 'refuted'. On any inference
        error, 'not_supported' (conservative default).
    """
    if not text or not subclaim:
        return "not_supported"

    prompt_content = inference_prompt(text, subclaim)

    # Format for the model's chat template.
    messages = [{"role": "user", "content": prompt_content}]
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to("cuda")

    try:
        # NOTE(review): `temperature` has no effect unless `do_sample=True` is
        # passed; generation is effectively greedy here — confirm intent.
        outputs = model.generate(
            input_ids=inputs,
            max_new_tokens=512,   # Kept from the original max_tokens=512
            temperature=0.1,      # Kept from the original temperature=0.1
            use_cache=True,
        )
        # Decode only the newly generated tokens.
        res = tokenizer.batch_decode(
            outputs[:, inputs.shape[1]:], skip_special_tokens=True
        )[0]
        res = res.strip().lower()

        # BUGFIX: the original split on an empty string ("" is a substring of
        # everything and str.split("") raises ValueError), so every call fell
        # into the except branch and returned "not_supported". The intent —
        # per the surrounding comment — was to strip a reasoning/thinking
        # block; "</think>" is the closing tag used by these reasoning models
        # (TODO: confirm against the model's chat template).
        if "</think>" in res:
            res = res.split("</think>", 1)[1].strip().lower()

        # Check "not_supported" first: "supported" is a substring of it.
        if "not_supported" in res:
            return "not_supported"
        elif "supported" in res:
            return "supported"
        elif "refuted" in res:
            return "refuted"
        else:
            return "not_supported"
    except Exception as e:
        # Conservative fallback: record the failure and default to unsupported.
        if error_log is not None:
            error_log.append(
                {
                    "subclaim": subclaim,
                    "error_msg": str(e),
                    "type": "LOCAL_INFERENCE_ERROR",
                }
            )
        return "not_supported"


# -----------------------------
# MAIN (Processing logic remains largely identical)
# -----------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_file",
        type=str,
        default="/home/mshahidul/readctrl/data/model_validity_check/subclaims_support_validity_check_gt_gpt5(1-5).json",
    )
    parser.add_argument(
        "--save_folder",
        type=str,
        default="/home/mshahidul/readctrl/data/concise_complete_attr_testing",
    )
    parser.add_argument("--start_index", type=int, default=0)
    parser.add_argument("--end_index", type=int, default=-1)
    args = parser.parse_args()

    INPUT_FILE = args.input_file
    SAVE_FOLDER = args.save_folder
    os.makedirs(SAVE_FOLDER, exist_ok=True)

    with open(INPUT_FILE, "r") as f:
        all_data = json.load(f)

    total_len = len(all_data)
    start = args.start_index
    end = args.end_index if args.end_index != -1 else total_len
    data_slice = all_data[start:min(end, total_len)]

    OUTPUT_FILE = os.path.join(
        SAVE_FOLDER, f"evaluated_metrics_{start}_{end}_nemotran-30B.json"
    )

    # Resume support: reload prior results; a corrupt/missing file starts fresh.
    processed_results = []
    if os.path.exists(OUTPUT_FILE):
        try:
            with open(OUTPUT_FILE, "r") as f:
                processed_results = json.load(f)
        except (json.JSONDecodeError, OSError):
            processed_results = []

    processed_ids = {item['medical_text'] for item in processed_results}
    global_error_log = []

    pbar = tqdm.tqdm(data_slice)
    for item in pbar:
        text = item.get('full_text', '')
        if text in processed_ids:
            continue  # Simple skip logic for resume

        subclaims = item.get('dat', {}).get('dat', [])
        for subclaim_obj in subclaims:
            subclaim_text = subclaim_obj.get('subclaim', '')
            label_gt = subclaim_obj.get('status', 'not_supported').strip().lower()
            label_gen = check_support(text, subclaim_text, error_log=global_error_log)
            processed_results.append(
                {
                    "medical_text": text,
                    "subclaim": subclaim_text,
                    "label_gt": label_gt,
                    "label_gen": label_gen,
                    "correctness": label_gen == label_gt,
                }
            )

        # Intermediate save ONCE PER ITEM (was once per subclaim). Saving
        # partial items was both O(n^2) I/O and a resume bug: an interruption
        # mid-item persisted some subclaims, and the text-based skip above
        # would then permanently drop the item's remaining subclaims.
        with open(OUTPUT_FILE, "w") as f:
            json.dump(processed_results, f, indent=2, ensure_ascii=False)

    # Final save
    with open(OUTPUT_FILE, "w") as f:
        json.dump(processed_results, f, indent=2, ensure_ascii=False)