# readCtrl_lambda / code / finetune-inference / old / nemotran_inference.py
# Author: mshahidul — initial commit of readCtrl code without large models (commit 030876e)
import os

# GPU selection must happen BEFORE torch/unsloth are imported, otherwise
# CUDA_VISIBLE_DEVICES has no effect on device enumeration.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import argparse
import json

import torch
import tqdm
from unsloth import FastLanguageModel
# -----------------------------
# UNSLOTH MODEL CONFIGURATION
# -----------------------------
# Local path to the fine-tuned Nemotron checkpoint used for subclaim verification.
MODEL_PATH = "/home/mshahidul/readctrl_model/full_model/nemotron-3-nano-30b-a3b_subclaims-support-check-8b_ctx_v2-bf16"
max_seq_length = 2048  # Adjusted for medical text + reasoning context
dtype = None  # Auto-detection for A100 (will likely use bfloat16)
load_in_4bit = True  # To fit 32B model comfortably on A100

# Load model and tokenizer natively.
# NOTE(review): trust_remote_code executes code shipped with the checkpoint —
# acceptable here only because MODEL_PATH is a locally controlled directory.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = MODEL_PATH,
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    trust_remote_code = True,
)

# Enable 2x faster native inference (puts the model in Unsloth's inference mode).
FastLanguageModel.for_inference(model)
# -----------------------------
# VERIFICATION PROMPT
# -----------------------------
def inference_prompt(text, subclaim):
    """Build the clinical-evidence-auditor prompt for one (text, subclaim) pair.

    The wording is fixed: the model must answer with exactly one word,
    'supported' or 'not_supported', judging the subclaim strictly against
    the provided medical text.
    """
    prompt = f"""You are a clinical evidence auditor. Your evaluation must be based STRICTLY and ONLY on the provided medical text.
### MANDATORY GROUNDING RULES:
1. NO OUTSIDE KNOWLEDGE: Do not use your internal medical knowledge. Even if a subclaim is "common sense" in medicine, if it is not explicitly in the TEXT, it is 'not_supported'.
2. NO LOGICAL LEAPS: Do not bridge gaps in logic. (e.g., If the text mentions "high blood sugar" but not the word "diabetes", you cannot support a claim of "diabetes").
3. EXACT NUMERICAL MATCHING: Any doses (e.g., 500mg), frequencies (e.g., twice daily), or durations (e.g., 10 days) mentioned in the subclaim must match the text perfectly. If they are missing or different in the text, label as 'not_supported'.
4. DEFAULT TO NOT SUPPORTED: If the text is vague, ambiguous, or only suggests a possibility, you MUST choose 'not_supported'.
5. CLOSED-WORLD REALITY: Treat the TEXT as the only information that exists in the world.
### Medical Text:
{text}
### Subclaim:
{subclaim}
Output exactly one word ('supported' or 'not_supported') based on the strict rules above:"""
    return prompt
# -----------------------------
# VERIFICATION LOGIC (UNSLOTH VERSION)
# -----------------------------
def check_support(text: str, subclaim: str, error_log=None) -> str:
    """Classify whether `subclaim` is backed by `text` using the local model.

    Returns 'supported', 'not_supported' or 'refuted'. Empty inputs,
    unparseable model output, or any inference exception fall back to
    'not_supported'. When `error_log` is a list, inference failures are
    appended to it as dicts.
    """
    # An empty text or subclaim can never be supported.
    if not text or not subclaim:
        return "not_supported"

    chat = [{"role": "user", "content": inference_prompt(text, subclaim)}]
    token_ids = tokenizer.apply_chat_template(
        chat,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to("cuda")

    try:
        generated = model.generate(
            input_ids=token_ids,
            max_new_tokens=512,   # mirrors the original API call's max_tokens=512
            temperature=0.1,      # mirrors the original API call's temperature=0.1
            use_cache=True,
        )
        # Keep only the newly generated tokens, then normalise to lowercase.
        reply = tokenizer.batch_decode(
            generated[:, token_ids.shape[1]:], skip_special_tokens=True
        )[0].strip().lower()
        # Discard any chain-of-thought preamble emitted before </think>.
        if "</think>" in reply:
            reply = reply.split("</think>")[1].strip().lower()
        # Check order matters: "not_supported" contains "supported" as a
        # substring, so it must be matched first.
        for label in ("not_supported", "supported", "refuted"):
            if label in reply:
                return label
        return "not_supported"
    except Exception as e:
        if error_log is not None:
            error_log.append(
                {"subclaim": subclaim, "error_msg": str(e), "type": "LOCAL_INFERENCE_ERROR"}
            )
        return "not_supported"
# -----------------------------
# MAIN (Processing logic remains largely identical)
# -----------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str,
                        default="/home/mshahidul/readctrl/data/model_validity_check/subclaims_support_validity_check_gt_gpt5(1-5).json")
    parser.add_argument("--save_folder", type=str,
                        default="/home/mshahidul/readctrl/data/concise_complete_attr_testing")
    parser.add_argument("--start_index", type=int, default=0)
    parser.add_argument("--end_index", type=int, default=-1)
    args = parser.parse_args()

    INPUT_FILE = args.input_file
    SAVE_FOLDER = args.save_folder
    os.makedirs(SAVE_FOLDER, exist_ok=True)

    with open(INPUT_FILE, "r") as f:
        all_data = json.load(f)

    # Restrict processing to the requested [start, end) slice; -1 means "to the end".
    total_len = len(all_data)
    start = args.start_index
    end = args.end_index if args.end_index != -1 else total_len
    data_slice = all_data[start:min(end, total_len)]

    OUTPUT_FILE = os.path.join(SAVE_FOLDER, f"evaluated_metrics_{start}_{end}_nemotran-30B.json")

    # Resume support: reload any previous results so already-evaluated texts
    # are skipped on restart.
    processed_results = []
    if os.path.exists(OUTPUT_FILE):
        try:
            with open(OUTPUT_FILE, "r") as f:
                processed_results = json.load(f)
        except (OSError, json.JSONDecodeError):
            # Corrupt or unreadable checkpoint — start this slice from scratch.
            processed_results = []

    # .get() guards against malformed checkpoint entries missing 'medical_text'.
    processed_ids = {item.get('medical_text') for item in processed_results}

    global_error_log = []
    pbar = tqdm.tqdm(data_slice)
    for item in pbar:
        text = item.get('full_text', '')
        if text in processed_ids:
            continue  # Simple skip logic for resume
        # NOTE(review): nested 'dat'/'dat' keys look unusual — confirm against
        # the input file's schema.
        subclaims = item.get('dat', {}).get('dat', [])
        for subclaim_obj in subclaims:
            subclaim_text = subclaim_obj.get('subclaim', '')
            label_gt = subclaim_obj.get('status', 'not_supported').strip().lower()
            label_gen = check_support(text, subclaim_text, error_log=global_error_log)
            result_entry = {
                "medical_text": text,
                "subclaim": subclaim_text,
                "label_gt": label_gt,
                "label_gen": label_gen,
                "correctness": label_gen == label_gt,
            }
            processed_results.append(result_entry)
            # Intermediate save after every subclaim so a crash loses at most
            # one evaluation (inference dominates runtime, so the rewrite cost
            # is acceptable).
            with open(OUTPUT_FILE, "w") as f:
                json.dump(processed_results, f, indent=2, ensure_ascii=False)

    # Final Save
    with open(OUTPUT_FILE, "w") as f:
        json.dump(processed_results, f, indent=2, ensure_ascii=False)