File size: 5,739 Bytes
030876e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import os
import json
import tqdm
import argparse
from openai import OpenAI

# -----------------------------
#  API CONFIGURATION
# -----------------------------
# Base URL of a locally hosted OpenAI-compatible inference server (vLLM-style).
LOCAL_API_URL = "http://172.16.34.29:8004/v1"
# Model identifier as registered with the server; passed as the `model` field
# of each chat-completion request.
LOCAL_MODEL_NAME = "/home/mshahidul/readctrl_model/full_model/qwen3-32B_subclaims-extraction-8b_ctx_fp16"

# Module-level client shared by all requests. The local endpoint does not
# check credentials, but the OpenAI client requires a non-empty api_key.
client = OpenAI(
    base_url=LOCAL_API_URL,
    api_key="EMPTY"
)

# -----------------------------
#  SUBCLAIM EXTRACTION PROMPT
# -----------------------------
def extraction_prompt(medical_text: str) -> str:
    """Build the annotation prompt that asks the model to extract subclaims
    from *medical_text* and return them as a JSON array of strings."""
    header = """You are an expert medical annotator.

Your task is to extract granular, factual subclaims from the provided medical text.
A subclaim is the smallest standalone factual unit that can be independently verified.

Instructions:
1. Read the medical text carefully.
2. Extract factual statements explicitly stated in the text.
3. Each subclaim must:
   - Contain exactly ONE factual assertion
   - Come directly from the text (no inference or interpretation)
   - Preserve original wording as much as possible
   - Include any negation, uncertainty, or qualifier (e.g., "may", "not", "suggests")
4. Do NOT:
   - Combine multiple facts into one subclaim
   - Add new information
   - Rephrase or normalize terminology
   - Include opinions or recommendations
5. Return ONLY a valid JSON array of strings.
6. Use double quotes and valid JSON formatting only (no markdown, no commentary).

Medical Text:
"""
    footer = """

Return format:
[
  "subclaim 1",
  "subclaim 2"
]"""
    # Plain concatenation keeps the text verbatim even if it contains braces.
    return header + medical_text + footer


# -----------------------------
#  INFERENCE FUNCTION (vLLM API)
# -----------------------------
def _extract_json_array(output_text: str) -> list:
    """Parse the outermost ``[...]`` span of *output_text* as a JSON list.

    Raises ValueError if no complete JSON array is present, or if the
    top-level JSON value is not a list (json.JSONDecodeError propagates
    for malformed JSON inside the span).
    """
    start_idx = output_text.find('[')
    end_idx = output_text.rfind(']') + 1
    if start_idx == -1 or end_idx <= start_idx:
        raise ValueError("Incomplete JSON list")
    parsed = json.loads(output_text[start_idx:end_idx])
    if not isinstance(parsed, list):
        raise ValueError("Top-level JSON value is not a list")
    return parsed


def infer_subclaims_api(medical_text: str, temperature: float = 0.2, max_tokens: int = 2048, retries: int = 1) -> list:
    """Extract subclaims from *medical_text* via the local chat-completions API.

    Makes up to ``retries + 1`` attempts, growing ``max_tokens`` by 2048 on
    each retry (the usual failure is a truncated JSON array). Returns the
    parsed list of subclaims; on total failure, falls back to a single-element
    list with the last raw model output (or an empty list if nothing was
    received), so callers always get a list.
    """
    if not medical_text or not medical_text.strip():
        return []

    prompt = extraction_prompt(medical_text)
    output_text = ""

    for attempt in range(retries + 1):
        try:
            response = client.chat.completions.create(
                model=LOCAL_MODEL_NAME,
                messages=[{"role": "user", "content": prompt}],
                temperature=temperature,
                max_tokens=max_tokens,
            )

            # content can be None on some API failures; guard before strip().
            output_text = (response.choices[0].message.content or "").strip()

            # Drop any chain-of-thought emitted before a </think> marker.
            if "</think>" in output_text:
                output_text = output_text.split("</think>")[-1].strip()

            return _extract_json_array(output_text)

        # Broad by design: network errors, API errors, and JSON parse
        # failures are all retried the same way with a larger budget.
        except Exception as e:
            if attempt < retries:
                max_tokens += 2048
                print(f"\n[Warning] API error/truncation: {e}. Retrying with {max_tokens} tokens...")

    # Best-effort fallback: surface the raw output rather than losing it.
    return [output_text] if output_text else []

# -----------------------------
#  MAIN EXECUTION
# -----------------------------
def _save_progress(path: str, records: dict) -> None:
    """Atomically write *records* (id -> entry dict) to *path* as a JSON list.

    Writes to a temp file and renames it over the target so an interrupt
    mid-write never leaves a corrupt checkpoint behind.
    """
    tmp_path = path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(list(records.values()), f, indent=4, ensure_ascii=False)
    os.replace(tmp_path, path)  # atomic on the same filesystem


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Extract subclaims for a slice of a JSON dataset.")
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--start", type=int, default=0, help="Start index in the dataset")
    parser.add_argument("--end", type=int, default=None, help="End index (exclusive) in the dataset")
    args = parser.parse_args()

    INPUT_FILE = args.input_file
    # Drop the extension (.json/.jsonl) to build the output file name.
    file_name = os.path.splitext(os.path.basename(INPUT_FILE))[0]
    SAVE_FOLDER = "/home/mshahidul/readctrl/data/extracting_subclaim"
    os.makedirs(SAVE_FOLDER, exist_ok=True)

    # Range-specific output naming helps if you want to run parallel jobs
    range_suffix = f"_{args.start}_{args.end if args.end is not None else 'end'}"
    OUTPUT_FILE = os.path.join(SAVE_FOLDER, f"extracted_subclaims_{file_name}{range_suffix}.json")

    with open(INPUT_FILE, "r", encoding="utf-8") as f:
        full_data = json.load(f)

    if args.end is None:
        args.end = len(full_data)

    # Slice the data based on user input
    data_subset = full_data[args.start:args.end]
    print(f"Processing range [{args.start} : {args.end}]. Total: {len(data_subset)} items.")

    # Resume support: load any entries already written for this exact range.
    processed_data = {}
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
            existing_list = json.load(f)
        processed_data = {str(item.get("id")): item for item in existing_list}

    new_items = 0
    for item in tqdm.tqdm(data_subset):
        item_id = str(item.get("id"))

        # Skip anything already present in the checkpoint file.
        if item_id in processed_data:
            continue

        # Fulltexts are longer, so give them a larger token budget and an
        # extra retry than the shorter summaries.
        f_sub = infer_subclaims_api(item.get("fulltext", ""), max_tokens=3072, retries=2)
        s_sub = infer_subclaims_api(item.get("summary", ""), max_tokens=2048, retries=1)

        processed_data[item_id] = {
            "id": item_id,
            "fulltext": item.get("fulltext", ""),
            "fulltext_subclaims": f_sub,
            "summary": item.get("summary", ""),
            "summary_subclaims": s_sub
        }

        # Checkpoint every 20 newly processed items (counting new work, not
        # the resumed total, so resuming mid-range still checkpoints).
        new_items += 1
        if new_items % 20 == 0:
            _save_progress(OUTPUT_FILE, processed_data)

    # Final Save
    _save_progress(OUTPUT_FILE, processed_data)

    print(f"Range extraction completed. File saved at: {OUTPUT_FILE}")