# Source: readCtrl_lambda/code/finetune-inference/old/inference_extract_subclaims_v3.py
# mshahidul
# Initial commit of readCtrl code without large models
# 030876e
import os
# Set GPU environment variables
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import torch
from unsloth import FastLanguageModel
import json
import tqdm
import argparse
# -----------------------------
# MODEL CACHE
# -----------------------------
# Process-wide cache so the model is loaded at most once per checkpoint path.
# "path" records which checkpoint the cached objects were loaded from.
_model_cache = {"model": None, "tokenizer": None, "path": None}


def load_finetuned_model(model_path: str):
    """Load the fine-tuned model and tokenizer from ``model_path``, with caching.

    The first call loads through ``FastLanguageModel.from_pretrained`` and
    stores the result in ``_model_cache``. Later calls with the same path
    return the cached pair. A call with a *different* path reloads instead of
    handing back the model that belongs to the earlier path.

    Args:
        model_path: Passed as ``model_name`` to
            ``FastLanguageModel.from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    cached_path = _model_cache.get("path")
    # A cache that was filled without a recorded path (e.g. populated by
    # other code) is treated as a hit for any path.
    if _model_cache["model"] is not None and cached_path in (None, model_path):
        return _model_cache["model"], _model_cache["tokenizer"]

    # Drop the cache's references to any previously loaded model before
    # loading another one, so the cache itself does not keep both alive.
    _model_cache.update({"model": None, "tokenizer": None, "path": None})

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        max_seq_length=8192,
        load_in_4bit=False,
        load_in_8bit=False,
        full_finetuning=False,
    )
    _model_cache.update(
        {"model": model, "tokenizer": tokenizer, "path": model_path}
    )
    return model, tokenizer
# -----------------------------
# SUBCLAIM EXTRACTION PROMPT
# -----------------------------
def extraction_prompt(medical_text: str) -> str:
    """Build the subclaim-extraction instruction prompt for ``medical_text``.

    The prompt consists of a fixed instruction header, the medical text on
    its own line, and a fixed footer showing the expected JSON list format.

    Args:
        medical_text: Text to embed after the ``Medical Text:`` line.

    Returns:
        The complete prompt string (it starts and ends with a newline).
    """
    header = (
        "\n"
        "You are an expert medical annotator. Your task is to extract granular, factual subclaims from medical text.\n"
        "A subclaim is the smallest standalone factual unit that can be independently verified.\n"
        "Instructions:\n"
        "1. Read the provided medical text.\n"
        "2. Break it into clear, objective, atomic subclaims.\n"
        "3. Each subclaim must come directly from the text.\n"
        "4. Return ONLY a valid JSON list of strings.\n"
        "Medical Text:\n"
    )
    footer = (
        "\n"
        "Return your output in JSON list format:\n"
        "[\n"
        '"subclaim 1",\n'
        '"subclaim 2"\n'
        "]\n"
    )
    # Interpolate (rather than concatenate) so non-string values are
    # formatted exactly as an f-string placeholder would format them.
    return f"{header}{medical_text}{footer}"
# -----------------------------
# INFERENCE FUNCTION WITH REPAIR
# -----------------------------
def infer_subclaims(medical_text: str, model, tokenizer, temperature: float = 0.2, max_tokens: int = 2048) -> list:
    """Extract subclaims from ``medical_text`` using the fine-tuned model.

    The text is wrapped in the extraction prompt, rendered through the
    tokenizer's chat template, and passed to ``model.generate``. Only the
    newly generated tokens are decoded, so the example JSON list embedded in
    the prompt can never be mistaken for the model's answer.

    Args:
        medical_text: Text to extract subclaims from.
        model: Model object exposing ``generate``.
        tokenizer: Tokenizer exposing ``apply_chat_template`` and ``decode``.
        temperature: Forwarded to ``model.generate``.
            NOTE(review): ``do_sample=False`` is also passed, so decoding is
            presumably greedy and this value has no effect - confirm whether
            sampling was intended.
        max_tokens: Upper bound on newly generated tokens.

    Returns:
        ``[]`` for empty/blank input; the parsed JSON list when the reply
        contains one; otherwise a one-element list holding the raw reply text.
    """
    if not medical_text or not medical_text.strip():
        return []

    prompt = extraction_prompt(medical_text)
    messages = [{"role": "user", "content": prompt}]
    chat_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(chat_text, return_tensors="pt").to("cuda")
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=False,
        )

    # The returned sequence starts with the prompt tokens; skip them so the
    # bracket search below only ever sees the model's own reply.
    generated_ids = output_ids[0][prompt_len:]
    output_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    # Remove reasoning/thinking if present
    if "</think>" in output_text:
        output_text = output_text.split("</think>")[-1].strip()

    # Parse the outermost [...] span of the reply as JSON.
    start_idx = output_text.find("[")
    end_idx = output_text.rfind("]")
    if start_idx != -1 and end_idx > start_idx:
        try:
            parsed = json.loads(output_text[start_idx:end_idx + 1])
        except (ValueError, RecursionError):
            # Malformed JSON: fall through to the raw-text fallback.
            parsed = None
        if isinstance(parsed, list):
            return parsed

    # Best-effort fallback: hand back the raw reply wrapped in a list.
    return [output_text]
# -----------------------------
# MAIN EXECUTION
# -----------------------------
if __name__ == "__main__":
    # Command-line entry point: extract subclaims for the "fulltext" and
    # "summary" of every record in --input_file, resuming from (and repairing)
    # an existing output file when one is present.
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument(
        "--save_folder",
        type=str,
        default="/home/mshahidul/readctrl/data/extracting_subclaim",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="/home/mshahidul/readctrl_model/qwen3-32B_subclaims-extraction-8b_ctx",
    )
    args = parser.parse_args()

    INPUT_FILE = args.input_file
    file_name = os.path.basename(INPUT_FILE).split(".json")[0]
    SAVE_FOLDER = args.save_folder
    MODEL_PATH = args.model_path
    os.makedirs(SAVE_FOLDER, exist_ok=True)
    OUTPUT_FILE = os.path.join(SAVE_FOLDER, f"extracted_subclaims_{file_name}_en.json")

    # Number of re-processed records between intermediate saves.
    SAVE_EVERY = 20

    def save_results(records: dict) -> None:
        """Write all records to OUTPUT_FILE through a temp file and a rename.

        Writing to a sibling temp file first means an interrupted save does
        not leave a truncated OUTPUT_FILE behind.
        """
        tmp_path = OUTPUT_FILE + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(list(records.values()), f, indent=4, ensure_ascii=False)
        os.replace(tmp_path, OUTPUT_FILE)

    def has_valid_subclaims(entry, key: str) -> bool:
        """Return True when ``entry`` holds a non-empty list under ``key``."""
        return bool(entry) and isinstance(entry.get(key), list) and len(entry[key]) > 0

    # Load input dataset first so a bad input path fails before the model loads.
    with open(INPUT_FILE, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Load existing results
    result = []
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
            result = json.load(f)

    # Convert results to a dict for easy lookup/update
    processed_data = {item["id"]: item for item in result}

    model, tokenizer = load_finetuned_model(MODEL_PATH)

    updates_since_save = 0
    for item in tqdm.tqdm(data):
        item_id = item.get("id")
        existing_entry = processed_data.get(item_id)

        # CHECK LOGIC:
        # If the entry exists and its subclaims are a non-empty list, reuse them.
        # If they are strings, empty, or missing, re-run the extraction.
        # NOTE(review): a failed parse in infer_subclaims comes back as a
        # one-element list of raw text, which passes this check and is
        # therefore not retried on later runs.

        # 1. Check Fulltext Subclaims
        fulltext_needs_update = not has_valid_subclaims(existing_entry, "fulltext_subclaims")
        if fulltext_needs_update:
            f_sub = infer_subclaims(item.get("fulltext", ""), model, tokenizer, max_tokens=3072)
        else:
            f_sub = existing_entry["fulltext_subclaims"]

        # 2. Check Summary Subclaims
        summary_needs_update = not has_valid_subclaims(existing_entry, "summary_subclaims")
        if summary_needs_update:
            s_sub = infer_subclaims(item.get("summary", ""), model, tokenizer, max_tokens=2048)
        else:
            s_sub = existing_entry["summary_subclaims"]

        # Update or Append
        processed_data[item_id] = {
            "id": item_id,
            "fulltext": item.get("fulltext", ""),
            "fulltext_subclaims": f_sub,
            "summary": item.get("summary", ""),
            "summary_subclaims": s_sub,
            "readability_score": item.get("readability_score", None),
        }

        # Intermediate save. Count the records that were actually re-run:
        # len(processed_data) does not change while resuming over records
        # that already exist, so it cannot serve as a progress counter.
        if fulltext_needs_update or summary_needs_update:
            updates_since_save += 1
            if updates_since_save >= SAVE_EVERY:
                save_results(processed_data)
                updates_since_save = 0

    # Final save
    save_results(processed_data)
    print(f"Refinement completed. Total records: {len(processed_data)}")