# readCtrl_lambda / code / subclaim_support_extraction / inference_extract_subclaims_vllm.py
# Author: mshahidul
# Initial commit of readCtrl code without large models (commit 030876e)
import os
import json
import tqdm
import argparse
from openai import OpenAI
# -----------------------------
# API CONFIGURATION
# -----------------------------
# vLLM serves an OpenAI-compatible API at this address.
LOCAL_API_URL = "http://172.16.34.29:8004/v1"
# Local filesystem path of the fine-tuned subclaim-extraction model loaded by the vLLM server.
LOCAL_MODEL_NAME = "/home/mshahidul/readctrl_model/full_model/qwen3-32B_subclaims-extraction-8b_ctx_fp16"
# vLLM does not validate the API key, but the OpenAI client requires a non-empty value.
client = OpenAI(
    base_url=LOCAL_API_URL,
    api_key="EMPTY"
)
# -----------------------------
# SUBCLAIM EXTRACTION PROMPT
# -----------------------------
def extraction_prompt(medical_text: str) -> str:
    """Build the instruction prompt asking the model to extract granular,
    independently verifiable subclaims from *medical_text*.

    The prompt demands the response be ONLY a JSON array of strings, which
    downstream parsing in infer_subclaims_api relies on. Interior lines of
    the f-string are runtime text and are deliberately left unindented; the
    trailing .strip() removes only the leading/trailing newlines.
    """
    return f"""
You are an expert medical annotator.
Your task is to extract granular, factual subclaims from the provided medical text.
A subclaim is the smallest standalone factual unit that can be independently verified.
Instructions:
1. Read the medical text carefully.
2. Extract factual statements explicitly stated in the text.
3. Each subclaim must:
- Contain exactly ONE factual assertion
- Come directly from the text (no inference or interpretation)
- Preserve original wording as much as possible
- Include any negation, uncertainty, or qualifier (e.g., "may", "not", "suggests")
4. Do NOT:
- Combine multiple facts into one subclaim
- Add new information
- Rephrase or normalize terminology
- Include opinions or recommendations
5. Return ONLY a valid JSON array of strings.
6. Use double quotes and valid JSON formatting only (no markdown, no commentary).
Medical Text:
{medical_text}
Return format:
[
"subclaim 1",
"subclaim 2"
]
""".strip()
# -----------------------------
# INFERENCE FUNCTION (vLLM API)
# -----------------------------
def infer_subclaims_api(medical_text: str, temperature: float = 0.2, max_tokens: int = 2048, retries: int = 1) -> list:
    """Extract subclaims from *medical_text* via the local vLLM chat endpoint.

    Sends the extraction prompt, strips any ``</think>`` reasoning block, and
    parses the first ``[...]`` span of the reply as JSON.

    Args:
        medical_text: Text to extract subclaims from; empty/blank input short-circuits.
        temperature: Sampling temperature passed to the API.
        max_tokens: Completion budget; grown by 2048 on each retry to
            counter truncation-induced invalid JSON.
        retries: Remaining retry attempts on API error or parse failure.

    Returns:
        A list of subclaim strings on success; on final failure, the raw
        model output wrapped in a single-element list (so data is not lost),
        or [] if no output was ever received.
    """
    if not medical_text or not medical_text.strip():
        return []
    prompt = extraction_prompt(medical_text)
    # Sentinel instead of the fragile `'output_text' in locals()` check.
    output_text = None
    try:
        response = client.chat.completions.create(
            model=LOCAL_MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens,
        )
        output_text = response.choices[0].message.content.strip()
        # Reasoning models may prepend a chain-of-thought block; keep only
        # the text after the final </think> tag.
        if "</think>" in output_text:
            output_text = output_text.split("</think>")[-1].strip()
        # Parse the outermost [...] span as the JSON array of subclaims.
        start_idx = output_text.find('[')
        end_idx = output_text.rfind(']') + 1
        if start_idx != -1 and end_idx > start_idx:
            parsed = json.loads(output_text[start_idx:end_idx])
            if isinstance(parsed, list):
                return parsed
        raise ValueError("Response did not contain a complete JSON list")
    # Broad catch is deliberate: network/API errors, JSONDecodeError, and the
    # ValueError above all funnel into the same retry path. (The original
    # tuple listed JSONDecodeError/ValueError redundantly next to Exception.)
    except Exception as e:
        if retries > 0:
            new_max = max_tokens + 2048
            print(f"\n[Warning] API error/truncation: {e}. Retrying with {new_max} tokens...")
            return infer_subclaims_api(medical_text, temperature, max_tokens=new_max, retries=retries - 1)
        return [output_text] if output_text is not None else []
# -----------------------------
# MAIN EXECUTION
# -----------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--start", type=int, default=0, help="Start index in the dataset")
    parser.add_argument("--end", type=int, default=None, help="End index (exclusive) in the dataset")
    args = parser.parse_args()

    INPUT_FILE = args.input_file
    # Base name without the .json extension, used in the output file name.
    file_name = os.path.basename(INPUT_FILE).split(".json")[0]
    SAVE_FOLDER = "/home/mshahidul/readctrl/data/extracting_subclaim"
    os.makedirs(SAVE_FOLDER, exist_ok=True)

    # Range-specific output naming helps if you want to run parallel jobs
    range_suffix = f"_{args.start}_{args.end if args.end is not None else 'end'}"
    OUTPUT_FILE = os.path.join(SAVE_FOLDER, f"extracted_subclaims_{file_name}{range_suffix}.json")

    with open(INPUT_FILE, "r") as f:
        full_data = json.load(f)

    # Normalize the open-ended range before slicing.
    if args.end is None:
        args.end = len(full_data)

    # Slice the data based on user input
    data_subset = full_data[args.start:args.end]
    # args.end is always an int here, so print it directly (the original
    # `args.end if args.end else ...` would misreport when end == 0).
    print(f"Processing range [{args.start} : {args.end}]. Total: {len(data_subset)} items.")

    # Resume support: load any previously written results, keyed by item id.
    processed_data = {}
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, "r") as f:
            existing_list = json.load(f)
        processed_data = {str(item.get("id")): item for item in existing_list}

    for item in tqdm.tqdm(data_subset):
        item_id = str(item.get("id"))
        # Skip items already completed in a previous run of this range.
        if item_id in processed_data:
            continue

        # 1. Process Fulltext — longer input, so larger budget and an extra retry.
        f_sub = infer_subclaims_api(item.get("fulltext", ""), max_tokens=3072, retries=2)
        # 2. Process Summary
        s_sub = infer_subclaims_api(item.get("summary", ""), max_tokens=2048, retries=1)

        # 3. Save Entry
        processed_data[item_id] = {
            "id": item_id,
            "fulltext": item.get("fulltext", ""),
            "fulltext_subclaims": f_sub,
            "summary": item.get("summary", ""),
            "summary_subclaims": s_sub,
        }

        # Periodic checkpoint every 20 accumulated items so a crash loses
        # at most 20 items of work. utf-8 matches ensure_ascii=False.
        if len(processed_data) % 20 == 0:
            with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
                json.dump(list(processed_data.values()), f, indent=4, ensure_ascii=False)

    # Final Save
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        json.dump(list(processed_data.values()), f, indent=4, ensure_ascii=False)
    print(f"Range extraction completed. File saved at: {OUTPUT_FILE}")