import os
import json
import argparse

# Pin GPU selection BEFORE importing torch/unsloth: CUDA_VISIBLE_DEVICES is
# only honored if it is set before the CUDA runtime initializes.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import tqdm
import torch
from unsloth import FastLanguageModel
# -----------------------------
# UNSLOTH MODEL CONFIGURATION
# -----------------------------
# Local path to the fine-tuned checkpoint used for subclaim support checks.
MODEL_PATH = "/home/mshahidul/readctrl_model/full_model/nemotron-3-nano-30b-a3b_subclaims-support-check-8b_ctx_v2-bf16"
max_seq_length = 2048 # Context budget for medical text + reasoning output
dtype = None # None -> Unsloth auto-detects (bfloat16 on A100-class GPUs)
load_in_4bit = True # 4-bit quantization to fit the model on one GPU
# NOTE(review): comment originally said "32B model" but the path names a
# 30B ("nemotron-3-nano-30b-a3b") — confirm the actual checkpoint size.
# Load model and tokenizer natively (module-level side effect: this downloads/
# loads weights onto the GPU at import time).
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = MODEL_PATH,
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
trust_remote_code = True,
)
# Enable 2x faster native inference (switches model into inference mode)
FastLanguageModel.for_inference(model)
# -----------------------------
# VERIFICATION PROMPT
# -----------------------------
def inference_prompt(text, subclaim):
    """Build the strict grounding prompt for one (text, subclaim) pair.

    The prompt instructs the model to act as a closed-world clinical
    evidence auditor and answer with exactly one word:
    'supported' or 'not_supported'.
    """
    prompt = f"""You are a clinical evidence auditor. Your evaluation must be based STRICTLY and ONLY on the provided medical text.
### MANDATORY GROUNDING RULES:
1. NO OUTSIDE KNOWLEDGE: Do not use your internal medical knowledge. Even if a subclaim is "common sense" in medicine, if it is not explicitly in the TEXT, it is 'not_supported'.
2. NO LOGICAL LEAPS: Do not bridge gaps in logic. (e.g., If the text mentions "high blood sugar" but not the word "diabetes", you cannot support a claim of "diabetes").
3. EXACT NUMERICAL MATCHING: Any doses (e.g., 500mg), frequencies (e.g., twice daily), or durations (e.g., 10 days) mentioned in the subclaim must match the text perfectly. If they are missing or different in the text, label as 'not_supported'.
4. DEFAULT TO NOT SUPPORTED: If the text is vague, ambiguous, or only suggests a possibility, you MUST choose 'not_supported'.
5. CLOSED-WORLD REALITY: Treat the TEXT as the only information that exists in the world.
### Medical Text:
{text}
### Subclaim:
{subclaim}
Output exactly one word ('supported' or 'not_supported') based on the strict rules above:"""
    return prompt
# -----------------------------
# VERIFICATION LOGIC (UNSLOTH VERSION)
# -----------------------------
def check_support(text: str, subclaim: str, error_log=None) -> str:
    """Classify whether `subclaim` is supported by `text` with the local model.

    Formats the verification prompt through the tokenizer's chat template,
    runs one generation on the globally loaded Unsloth model, and maps the
    model's free-text answer onto a label.

    Args:
        text: The source medical text (the only allowed evidence).
        subclaim: The claim to verify against `text`.
        error_log: Optional list; on failure a dict describing the error is
            appended instead of the exception propagating.

    Returns:
        'supported', 'not_supported', or 'refuted'. Defaults to
        'not_supported' for empty input, ambiguous output, or any error.
    """
    if not text or not subclaim:
        return "not_supported"

    prompt_content = inference_prompt(text, subclaim)
    messages = [{"role": "user", "content": prompt_content}]

    try:
        # BUGFIX: tokenization and the .to("cuda") transfer are now inside the
        # try block — previously a failure here (e.g. OOM, template error)
        # bypassed error_log and crashed the whole batch loop.
        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize = True,
            add_generation_prompt = True,
            return_tensors = "pt",
        ).to("cuda")

        # NOTE(review): `temperature` has no effect unless sampling is enabled
        # (do_sample=True), so this generation is effectively greedy — confirm
        # that greedy decoding is the intent.
        outputs = model.generate(
            input_ids = inputs,
            max_new_tokens = 512,
            temperature = 0.1,
            use_cache = True,
        )

        # Decode only the newly generated tokens (skip the prompt tokens).
        res = tokenizer.batch_decode(outputs[:, inputs.shape[1]:], skip_special_tokens=True)[0]
        res = res.strip().lower()

        # Drop any "thinking" preamble. BUGFIX: [-1] instead of [1] keeps the
        # text after the LAST </think> marker, which is correct even if the
        # model emits more than one thinking segment.
        if "</think>" in res:
            res = res.split("</think>")[-1].strip().lower()

        # Order matters: 'not_supported' contains 'supported' as a substring,
        # so it must be matched first.
        if "not_supported" in res:
            return "not_supported"
        elif "supported" in res:
            return "supported"
        elif "refuted" in res:
            return "refuted"
        else:
            return "not_supported"
    except Exception as e:  # deliberate best-effort: never crash the batch run
        if error_log is not None:
            error_log.append({
                "subclaim": subclaim,
                "error_msg": str(e),
                "type": "LOCAL_INFERENCE_ERROR",
            })
        return "not_supported"
# -----------------------------
# MAIN (Processing logic remains largely identical)
# -----------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str,
                        default="/home/mshahidul/readctrl/data/model_validity_check/subclaims_support_validity_check_gt_gpt5(1-5).json")
    parser.add_argument("--save_folder", type=str,
                        default="/home/mshahidul/readctrl/data/concise_complete_attr_testing")
    parser.add_argument("--start_index", type=int, default=0)
    parser.add_argument("--end_index", type=int, default=-1)
    args = parser.parse_args()

    INPUT_FILE = args.input_file
    SAVE_FOLDER = args.save_folder
    os.makedirs(SAVE_FOLDER, exist_ok=True)

    with open(INPUT_FILE, "r") as f:
        all_data = json.load(f)

    # Slice the dataset so several shards can run in parallel; end_index == -1
    # means "to the end".
    total_len = len(all_data)
    start = args.start_index
    end = args.end_index if args.end_index != -1 else total_len
    data_slice = all_data[start:min(end, total_len)]

    OUTPUT_FILE = os.path.join(SAVE_FOLDER, f"evaluated_metrics_{start}_{end}_nemotran-30B.json")

    # Resume support: reload any previously saved results for this shard.
    processed_results = []
    if os.path.exists(OUTPUT_FILE):
        try:
            with open(OUTPUT_FILE, "r") as f:
                processed_results = json.load(f)
        # BUGFIX: was a bare `except:`, which also swallowed SystemExit and
        # KeyboardInterrupt. Only a corrupt/unreadable checkpoint should
        # trigger a fresh start.
        except (OSError, json.JSONDecodeError):
            processed_results = []

    # Items already fully processed are identified by their medical text.
    processed_ids = {item['medical_text'] for item in processed_results}

    global_error_log = []
    pbar = tqdm.tqdm(data_slice)
    for item in pbar:
        text = item.get('full_text', '')
        if text in processed_ids:
            continue  # Simple skip logic for resume
        subclaims = item.get('dat', {}).get('dat', [])
        for subclaim_obj in subclaims:
            subclaim_text = subclaim_obj.get('subclaim', '')
            label_gt = subclaim_obj.get('status', 'not_supported').strip().lower()
            label_gen = check_support(text, subclaim_text, error_log=global_error_log)
            processed_results.append({
                "medical_text": text,
                "subclaim": subclaim_text,
                "label_gt": label_gt,
                "label_gen": label_gen,
                "correctness": label_gen == label_gt,
            })
            # Intermediate save after every subclaim so a crash loses at most
            # one generation (rewrites the whole file each time).
            with open(OUTPUT_FILE, "w") as f:
                json.dump(processed_results, f, indent=2, ensure_ascii=False)

    # Final Save
    with open(OUTPUT_FILE, "w") as f:
        json.dump(processed_results, f, indent=2, ensure_ascii=False)

    # BUGFIX: the error log was collected but never written anywhere, so all
    # inference failures silently vanished at exit. Persist it next to the
    # results when non-empty.
    if global_error_log:
        ERROR_FILE = os.path.join(SAVE_FOLDER, f"errors_{start}_{end}_nemotran-30B.json")
        with open(ERROR_FILE, "w") as f:
            json.dump(global_error_log, f, indent=2, ensure_ascii=False)