# readCtrl_lambda/code/subclaim_support_extraction/inference_extract_subclaims_v4.py
# mshahidul
# Initial commit of readCtrl code without large models
# 030876e
import os
# Set GPU environment variables
# NOTE: these are assigned before `torch` is imported below; presumably so
# that CUDA only ever sees the selected device -- keep this ordering.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # number GPUs by PCI bus ID
os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # expose only GPU 2 to this process
import torch
from unsloth import FastLanguageModel
import json
import tqdm
import argparse
# -----------------------------
# MODEL CACHE
# -----------------------------
# Single-slot cache: holds the most recently loaded model/tokenizer together
# with the path it was loaded from.
_model_cache = {"model": None, "tokenizer": None, "path": None}


def load_finetuned_model(model_path: str):
    """Load the fine-tuned model and tokenizer from ``model_path`` (cached).

    The result is memoised in ``_model_cache`` so that repeated calls with the
    same path do not reload the weights. A call with a *different* path
    reloads instead of silently returning the previously cached model.

    Args:
        model_path: Path or name passed to
            ``FastLanguageModel.from_pretrained`` as ``model_name``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    cached_model = _model_cache["model"]
    if cached_model is not None and _model_cache.get("path") in (None, model_path):
        # ``path`` is None only if the cache was filled by code that predates
        # path tracking; treat that as a hit, as the original did.
        return cached_model, _model_cache["tokenizer"]

    # Drop our references to any previously cached model before loading the
    # new one so its memory can be reclaimed if nothing else holds it.
    _model_cache.update(model=None, tokenizer=None, path=None)

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        max_seq_length=8192,
        load_in_4bit=False,
        load_in_8bit=False,
        full_finetuning=False,
    )
    _model_cache.update(model=model, tokenizer=tokenizer, path=model_path)
    return model, tokenizer
# -----------------------------
# SUBCLAIM EXTRACTION PROMPT
# -----------------------------
def extraction_prompt(medical_text: str) -> str:
    """Build the instruction prompt asking for subclaims of ``medical_text``.

    The prompt describes the annotation task, embeds the text verbatim, and
    shows the expected JSON-list output format. It starts and ends with a
    newline.

    Args:
        medical_text: The passage to be split into atomic subclaims.

    Returns:
        The complete prompt string.
    """
    prompt_lines = (
        "",  # leading newline
        "You are an expert medical annotator. Your task is to extract granular, factual subclaims from medical text.",
        "A subclaim is the smallest standalone factual unit that can be independently verified.",
        "Instructions:",
        "1. Read the provided medical text.",
        "2. Break it into clear, objective, atomic subclaims.",
        "3. Each subclaim must come directly from the text.",
        "4. Return ONLY a valid JSON list of strings.",
        "Medical Text:",
        f"{medical_text}",
        "Return your output in JSON list format:",
        "[",
        '"subclaim 1",',
        '"subclaim 2"',
        "]",
        "",  # trailing newline
    )
    return "\n".join(prompt_lines)
# -----------------------------
# INFERENCE FUNCTION WITH AUTO-RETRY
# -----------------------------
def _parse_json_list(text: str):
    """Return the JSON list spanning the first '[' to the last ']' in ``text``, or None."""
    start_idx = text.find("[")
    end_idx = text.rfind("]") + 1
    # Need a complete bracketed pair; a missing ']' usually means truncation.
    if start_idx == -1 or end_idx <= start_idx:
        return None
    try:
        parsed = json.loads(text[start_idx:end_idx])
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, list) else None


def infer_subclaims(medical_text: str, model, tokenizer, temperature: float = 0.2, max_tokens: int = 2048, retries: int = 1) -> list:
    """Extract subclaims from ``medical_text`` with the given model.

    The text is wrapped in the extraction prompt, run through the tokenizer's
    chat template and generated from. If the answer cannot be parsed as a
    JSON list, generation is retried with a token budget enlarged by 2048
    per retry.

    Args:
        medical_text: Text to split into subclaims. Empty/blank/None -> ``[]``.
        model: Object exposing ``generate(**inputs, ...)``.
        tokenizer: Tokenizer exposing ``apply_chat_template``, ``__call__``
            and ``decode``.
        temperature: Forwarded to ``generate``. NOTE(review): decoding runs
            with ``do_sample=False``, so this should not influence the
            output; it is kept for signature compatibility.
        max_tokens: ``max_new_tokens`` for the first attempt.
        retries: Number of additional attempts after a parse failure.

    Returns:
        The parsed JSON list on success. If every attempt fails, a
        one-element list holding the raw generated text of the last attempt,
        so downstream code does not crash.
    """
    if not medical_text or medical_text.strip() == "":
        return []

    prompt = extraction_prompt(medical_text)
    messages = [{"role": "user", "content": prompt}]
    chat_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(chat_text, return_tensors="pt").to("cuda")
    prompt_len = inputs["input_ids"].shape[1]

    token_budget = max_tokens
    output_text = ""
    for attempts_left in range(max(retries, 0), -1, -1):
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=token_budget,
                temperature=temperature,
                do_sample=False,
            )

        # generate() returns prompt + continuation for decoder-only models.
        # Decode only the new tokens: the prompt itself contains an example
        # JSON list ("subclaim 1", "subclaim 2") that the bracket search
        # below would otherwise pick up.
        generated_ids = output_ids[0][prompt_len:]
        output_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        # Remove reasoning if model is a "Thinker" model
        if "</think>" in output_text:
            output_text = output_text.split("</think>")[-1].strip()

        parsed = _parse_json_list(output_text)
        if parsed is not None:
            return parsed

        # Parsing failed or brackets were incomplete (likely truncation):
        # grow the budget by 2k tokens and try again while retries remain.
        if attempts_left > 0:
            token_budget += 2048
            print(f"\n[Warning] Truncation detected. Retrying with {token_budget} tokens...")

    # Final fallback: return the raw text wrapped in a list so the pipeline doesn't crash
    return [output_text]
# -----------------------------
# MAIN EXECUTION
# -----------------------------
def resolve_item_id(item: dict) -> str:
    """Return the resume key of a record: its "index" if not None, else its "id".

    Used both when loading previous results and when iterating the input so
    the two always agree (an "index" of 0 is a valid key, not a missing one).
    """
    index = item.get("index")
    return str(index if index is not None else item.get("id"))


def save_results(entries: dict, output_file: str) -> None:
    """Write all entries to ``output_file`` as a JSON list, atomically.

    The data is written to a sibling ``.tmp`` file first and then swapped in,
    so an interruption mid-write cannot corrupt the resume file.
    """
    tmp_file = f"{output_file}.tmp"
    with open(tmp_file, "w", encoding="utf-8") as f:
        json.dump(list(entries.values()), f, indent=4, ensure_ascii=False)
    os.replace(tmp_file, output_file)


def main() -> None:
    """Extract subclaims for every record of the input JSON file.

    For each record, subclaims are produced for "fulltext", "summary" and
    every text under "diff_label_texts". Results go to
    ``extracted_subclaims_<input name>.json`` in the save folder. If that
    file already exists, fields that already hold a list are reused and only
    the missing ones are regenerated.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--save_folder", type=str,
                        default="/home/mshahidul/readctrl/data/extracting_subclaim")
    parser.add_argument("--model_path", type=str,
                        default="/home/mshahidul/readctrl_model/qwen3-32B_subclaims-extraction-8b_ctx")
    args = parser.parse_args()

    input_file = args.input_file
    file_name = os.path.basename(input_file).split(".json")[0]
    save_folder = args.save_folder
    os.makedirs(save_folder, exist_ok=True)
    output_file = os.path.join(save_folder, f"extracted_subclaims_{file_name}.json")

    model, tokenizer = load_finetuned_model(args.model_path)

    with open(input_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Resume support: index previously saved entries by the same key that is
    # used for the input records below.
    result = []
    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as f:
            result = json.load(f)
    processed_data = {resolve_item_id(entry): entry for entry in result}

    for count, item in enumerate(tqdm.tqdm(data), start=1):
        item_id = resolve_item_id(item)
        existing_entry = processed_data.get(item_id)

        # 1. Process Fulltext (The longest field, high initial token count)
        if not existing_entry or not isinstance(existing_entry.get("fulltext_subclaims"), list):
            f_sub = infer_subclaims(item.get("fulltext", ""), model, tokenizer, max_tokens=3072, retries=2)
        else:
            f_sub = existing_entry["fulltext_subclaims"]

        # 2. Process Summary
        if not existing_entry or not isinstance(existing_entry.get("summary_subclaims"), list):
            s_sub = infer_subclaims(item.get("summary", ""), model, tokenizer, max_tokens=2048, retries=1)
        else:
            s_sub = existing_entry["summary_subclaims"]

        # 3. Process All Generated Texts (diff_label_texts)
        # `or {}` tolerates an explicit null in the input / resume file.
        diff_label_texts = item.get("diff_label_texts") or {}
        diff_label_subclaims = existing_entry.get("diff_label_subclaims") if existing_entry else None
        if not isinstance(diff_label_subclaims, dict):
            diff_label_subclaims = {}
        for label, text in diff_label_texts.items():
            if not isinstance(diff_label_subclaims.get(label), list):
                # Generated texts are shorter, but we still allow 1 retry
                diff_label_subclaims[label] = infer_subclaims(text, model, tokenizer, max_tokens=1536, retries=1)

        # 4. Save
        processed_data[item_id] = {
            "index": item.get("index"),
            "id": item.get("id"),
            "fulltext": item.get("fulltext", ""),
            "fulltext_subclaims": f_sub,
            "summary": item.get("summary", ""),
            "summary_subclaims": s_sub,
            "diff_label_texts": diff_label_texts,
            "diff_label_subclaims": diff_label_subclaims,
            "readability_score": item.get("readability_score", None),
        }

        # Checkpoint every 10 records of this run. Counting iterations rather
        # than len(processed_data) keeps the cadence steady while resuming.
        if count % 10 == 0:
            save_results(processed_data, output_file)

    save_results(processed_data, output_file)
    print(f"Extraction completed. File saved at: {output_file}")


if __name__ == "__main__":
    main()