import os
import json
import tqdm
import argparse
from openai import OpenAI
# -----------------------------
# API CONFIGURATION
# -----------------------------
# Base URL of the local vLLM server exposing an OpenAI-compatible /v1 API.
LOCAL_API_URL = "http://172.16.34.29:8004/v1"
# Model identifier passed to the server; for vLLM this is the on-disk path
# the model was served from (fine-tuned Qwen3-32B subclaim extractor).
LOCAL_MODEL_NAME = "/home/mshahidul/readctrl_model/full_model/qwen3-32B_subclaims-extraction-8b_ctx_fp16"
# The local server does not validate API keys; "EMPTY" is the conventional
# placeholder the OpenAI client still requires.
client = OpenAI(
base_url=LOCAL_API_URL,
api_key="EMPTY"
)
# -----------------------------
# SUBCLAIM EXTRACTION PROMPT
# -----------------------------
def extraction_prompt(medical_text: str) -> str:
    """Build the subclaim-extraction prompt for one medical text.

    The prompt instructs the model to return only a JSON array of
    granular, independently verifiable factual statements taken
    verbatim from *medical_text*.
    """
    template = f"""
You are an expert medical annotator.
Your task is to extract granular, factual subclaims from the provided medical text.
A subclaim is the smallest standalone factual unit that can be independently verified.
Instructions:
1. Read the medical text carefully.
2. Extract factual statements explicitly stated in the text.
3. Each subclaim must:
- Contain exactly ONE factual assertion
- Come directly from the text (no inference or interpretation)
- Preserve original wording as much as possible
- Include any negation, uncertainty, or qualifier (e.g., "may", "not", "suggests")
4. Do NOT:
- Combine multiple facts into one subclaim
- Add new information
- Rephrase or normalize terminology
- Include opinions or recommendations
5. Return ONLY a valid JSON array of strings.
6. Use double quotes and valid JSON formatting only (no markdown, no commentary).
Medical Text:
{medical_text}
Return format:
[
"subclaim 1",
"subclaim 2"
]
"""
    # Drop the leading/trailing newlines introduced by the triple-quoted literal.
    return template.strip()
# -----------------------------
# INFERENCE FUNCTION (vLLM API)
# -----------------------------
def infer_subclaims_api(medical_text: str, temperature: float = 0.2, max_tokens: int = 2048, retries: int = 1) -> list:
    """Extract subclaims from *medical_text* via the local vLLM chat API.

    Sends the extraction prompt, strips any ``</think>`` reasoning prefix
    from the reply, and parses the first JSON array found in the output.

    Args:
        medical_text: Source text; empty or whitespace-only input yields [].
        temperature: Sampling temperature for the completion.
        max_tokens: Completion token budget; grown by 2048 on each retry,
            since truncated replies usually mean the budget was too small.
        retries: Number of retry attempts remaining after a failure.

    Returns:
        The parsed list of subclaim strings. On unrecoverable failure,
        the raw model output wrapped in a one-element list (so callers can
        inspect it), or [] if no output was ever produced.
    """
    if not medical_text or not medical_text.strip():
        return []
    prompt = extraction_prompt(medical_text)
    # Initialized before the try so the fallback return below is always safe
    # (replaces the fragile `'output_text' in locals()` check).
    output_text = ""
    try:
        response = client.chat.completions.create(
            model=LOCAL_MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens,
        )
        output_text = response.choices[0].message.content.strip()
        # Reasoning models may prefix a chain-of-thought; keep only the answer.
        if "</think>" in output_text:
            output_text = output_text.split("</think>")[-1].strip()
        # Grab the outermost JSON array embedded in the reply.
        start_idx = output_text.find('[')
        end_idx = output_text.rfind(']') + 1
        if start_idx != -1 and end_idx > start_idx:
            parsed = json.loads(output_text[start_idx:end_idx])
            if isinstance(parsed, list):
                return parsed
        raise ValueError("Incomplete JSON list")
    except Exception as e:  # covers API errors, JSON decode errors, truncation
        if retries > 0:
            new_max = max_tokens + 2048
            print(f"\n[Warning] API error/truncation: {e}. Retrying with {new_max} tokens...")
            return infer_subclaims_api(medical_text, temperature, max_tokens=new_max, retries=retries - 1)
        return [output_text] if output_text else []
# -----------------------------
# MAIN EXECUTION
# -----------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--start", type=int, default=0, help="Start index in the dataset")
    parser.add_argument("--end", type=int, default=None, help="End index (exclusive) in the dataset")
    args = parser.parse_args()

    INPUT_FILE = args.input_file
    # Strip the .json extension to derive a base name for the output file.
    file_name = os.path.basename(INPUT_FILE).split(".json")[0]
    SAVE_FOLDER = "/home/mshahidul/readctrl/data/extracting_subclaim"
    os.makedirs(SAVE_FOLDER, exist_ok=True)

    # Range-specific output naming helps if you want to run parallel jobs.
    # Computed before the --end default is resolved, so an open-ended run
    # is tagged "_<start>_end" rather than the dataset length.
    range_suffix = f"_{args.start}_{args.end if args.end is not None else 'end'}"
    OUTPUT_FILE = os.path.join(SAVE_FOLDER, f"extracted_subclaims_{file_name}{range_suffix}.json")

    # Explicit encoding: output is written with ensure_ascii=False, so
    # relying on the platform default encoding would be fragile.
    with open(INPUT_FILE, "r", encoding="utf-8") as f:
        full_data = json.load(f)
    if args.end is None:
        args.end = len(full_data)

    # Slice the data based on user input
    data_subset = full_data[args.start:args.end]
    # args.end is guaranteed set above, so print it directly.
    print(f"Processing range [{args.start} : {args.end}]. Total: {len(data_subset)} items.")

    # Load existing progress if available, keyed by stringified item id.
    # NOTE(review): items without an "id" collapse onto the "None" key —
    # assumes every record carries a unique id; verify against the dataset.
    processed_data = {}
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
            existing_list = json.load(f)
        processed_data = {str(item.get("id")): item for item in existing_list}

    for item in tqdm.tqdm(data_subset):
        item_id = str(item.get("id"))
        # Check if this item in the subset was already processed
        if item_id in processed_data:
            continue
        # 1. Process Fulltext (larger budget + extra retry for long documents)
        f_sub = infer_subclaims_api(item.get("fulltext", ""), max_tokens=3072, retries=2)
        # 2. Process Summary
        s_sub = infer_subclaims_api(item.get("summary", ""), max_tokens=2048, retries=1)
        # 3. Save Entry
        processed_data[item_id] = {
            "id": item_id,
            "fulltext": item.get("fulltext", ""),
            "fulltext_subclaims": f_sub,
            "summary": item.get("summary", ""),
            "summary_subclaims": s_sub
        }
        # Periodic checkpoint every 20 accumulated items so a crash loses
        # at most one batch of work.
        if len(processed_data) % 20 == 0:
            with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
                json.dump(list(processed_data.values()), f, indent=4, ensure_ascii=False)

    # Final Save
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        json.dump(list(processed_data.values()), f, indent=4, ensure_ascii=False)
    print(f"Range extraction completed. File saved at: {OUTPUT_FILE}")