File size: 6,601 Bytes
030876e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import os
import json
import argparse

# CUDA device selection must happen BEFORE torch/unsloth are imported,
# otherwise the visibility mask is ignored by the CUDA runtime.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import torch
import tqdm
from unsloth import FastLanguageModel

# -----------------------------
#  UNSLOTH MODEL CONFIGURATION
# -----------------------------
MODEL_PATH = "/home/mshahidul/readctrl_model/full_model/nemotron-3-nano-30b-a3b_subclaims-support-check-8b_ctx_v2-bf16"
max_seq_length = 2048   # context budget for medical text + reasoning
dtype = None            # let Unsloth auto-detect (likely bfloat16 on A100)
load_in_4bit = True     # 4-bit quantization so the 30B model fits comfortably on one GPU

# Load model and tokenizer natively.
# NOTE(review): trust_remote_code=True executes Python shipped with the
# checkpoint -- acceptable only because this is a locally-built model.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = MODEL_PATH,
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    trust_remote_code = True,
)

# Enable 2x faster native inference (disables training-only code paths).
FastLanguageModel.for_inference(model)

# -----------------------------
#  VERIFICATION PROMPT
# -----------------------------
def inference_prompt(text, subclaim):
    """Build the strict closed-world verification prompt.

    Args:
        text: The medical source text (the only evidence the model may use).
        subclaim: The single claim to check against ``text``.

    Returns:
        A prompt string instructing the model to reply with exactly one
        word: 'supported' or 'not_supported'.
    """
    # NOTE: the exact bytes of this prompt (including trailing spaces) are
    # runtime behavior seen by the model -- edit with care.
    return f"""You are a clinical evidence auditor. Your evaluation must be based STRICTLY and ONLY on the provided medical text. 

### MANDATORY GROUNDING RULES:
1. NO OUTSIDE KNOWLEDGE: Do not use your internal medical knowledge. Even if a subclaim is "common sense" in medicine, if it is not explicitly in the TEXT, it is 'not_supported'.
2. NO LOGICAL LEAPS: Do not bridge gaps in logic. (e.g., If the text mentions "high blood sugar" but not the word "diabetes", you cannot support a claim of "diabetes").
3. EXACT NUMERICAL MATCHING: Any doses (e.g., 500mg), frequencies (e.g., twice daily), or durations (e.g., 10 days) mentioned in the subclaim must match the text perfectly. If they are missing or different in the text, label as 'not_supported'.
4. DEFAULT TO NOT SUPPORTED: If the text is vague, ambiguous, or only suggests a possibility, you MUST choose 'not_supported'. 
5. CLOSED-WORLD REALITY: Treat the TEXT as the only information that exists in the world.

### Medical Text:
{text}

### Subclaim:
{subclaim}

Output exactly one word ('supported' or 'not_supported') based on the strict rules above:"""

# -----------------------------
#  VERIFICATION LOGIC (UNSLOTH VERSION)
# -----------------------------
def check_support(text: str, subclaim: str, error_log=None) -> str:
    """Classify whether ``subclaim`` is supported by ``text``.

    Runs the loaded Unsloth model on the strict-grounding prompt and maps
    the generation to a label. Empty inputs and any inference failure
    default conservatively to 'not_supported'.

    Args:
        text: Source medical text used as the only evidence.
        subclaim: Claim to verify against the text.
        error_log: Optional list; inference error details are appended to it.

    Returns:
        'supported', 'not_supported', or 'refuted'.
    """
    if not text or not subclaim:
        return "not_supported"

    prompt_content = inference_prompt(text, subclaim)

    # Render through the model's chat template so special tokens match the
    # format the checkpoint expects.
    messages = [{"role": "user", "content": prompt_content}]
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize = True,
        add_generation_prompt = True,
        return_tensors = "pt",
    ).to("cuda")

    try:
        # NOTE(review): temperature has no effect unless do_sample=True is
        # also passed -- generation is effectively greedy here; confirm
        # that is intended before relying on temperature=0.1.
        outputs = model.generate(
            input_ids = inputs,
            max_new_tokens = 512,
            temperature = 0.1,
            use_cache = True,
        )

        # Decode only the newly generated tokens (everything after the prompt).
        res = tokenizer.batch_decode(outputs[:, inputs.shape[1]:], skip_special_tokens=True)[0]
        res = res.strip().lower()

        # Discard any reasoning block: keep only text after the LAST closing
        # tag. (The original took split(...)[1], which grabs the wrong
        # segment if the model emits more than one </think>.)
        if "</think>" in res:
            res = res.rsplit("</think>", 1)[-1].strip()

        # Order matters: "supported" is a substring of "not_supported",
        # so the negative label must be checked first.
        if "not_supported" in res:
            return "not_supported"
        elif "supported" in res:
            return "supported"
        elif "refuted" in res:
            return "refuted"
        else:
            return "not_supported"

    except Exception as e:
        if error_log is not None:
            error_details = {"subclaim": subclaim, "error_msg": str(e), "type": "LOCAL_INFERENCE_ERROR"}
            error_log.append(error_details)
        return "not_supported"

# -----------------------------
#  MAIN (Processing logic remains largely identical)
# -----------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str,
                        default="/home/mshahidul/readctrl/data/model_validity_check/subclaims_support_validity_check_gt_gpt5(1-5).json")
    parser.add_argument("--save_folder", type=str,
                        default="/home/mshahidul/readctrl/data/concise_complete_attr_testing")
    parser.add_argument("--start_index", type=int, default=0)
    parser.add_argument("--end_index", type=int, default=-1)

    args = parser.parse_args()

    INPUT_FILE = args.input_file
    SAVE_FOLDER = args.save_folder
    os.makedirs(SAVE_FOLDER, exist_ok=True)

    with open(INPUT_FILE, "r") as f:
        all_data = json.load(f)

    # Clamp the requested [start, end) slice to the data actually available;
    # end_index == -1 means "through the end".
    total_len = len(all_data)
    start = args.start_index
    end = args.end_index if args.end_index != -1 else total_len
    data_slice = all_data[start:min(end, total_len)]

    OUTPUT_FILE = os.path.join(SAVE_FOLDER, f"evaluated_metrics_{start}_{end}_nemotran-30B.json")

    # Resume support: reload any previous partial results. The original used
    # a bare `except:` here, which swallows every exception (including
    # KeyboardInterrupt/SystemExit); only an unreadable or corrupt file
    # should mean "start fresh".
    processed_results = []
    if os.path.exists(OUTPUT_FILE):
        try:
            with open(OUTPUT_FILE, "r") as f:
                processed_results = json.load(f)
        except (OSError, json.JSONDecodeError):
            processed_results = []

    # Texts already present in the output are skipped on resume. Results are
    # only written to disk after ALL of an item's subclaims finish, so a
    # crash mid-item cannot leave a half-processed text in this set.
    processed_ids = {item['medical_text'] for item in processed_results}
    global_error_log = []

    pbar = tqdm.tqdm(data_slice)

    for item in pbar:
        text = item.get('full_text', '')
        if text in processed_ids:
            continue  # simple skip logic for resume

        subclaims = item.get('dat', {}).get('dat', [])

        for subclaim_obj in subclaims:
            subclaim_text = subclaim_obj.get('subclaim', '')
            label_gt = subclaim_obj.get('status', 'not_supported').strip().lower()

            label_gen = check_support(text, subclaim_text, error_log=global_error_log)

            processed_results.append({
                "medical_text": text,
                "subclaim": subclaim_text,
                "label_gt": label_gt,
                "label_gen": label_gen,
                "correctness": label_gen == label_gt,
            })

        # Intermediate save after every item so progress survives crashes.
        with open(OUTPUT_FILE, "w") as f:
            json.dump(processed_results, f, indent=2, ensure_ascii=False)

    # Final save (also covers the case where data_slice is empty).
    with open(OUTPUT_FILE, "w") as f:
        json.dump(processed_results, f, indent=2, ensure_ascii=False)