# Source: readctrl / code / finetune-inference / old / extracting_subclaims.py
# Uploaded by shahidul034 ("Add files using upload-large-folder tool", commit 9c6961c verified)
# NOTE(review): the two lines above are Hugging Face upload-page residue that was
# pasted into the file as bare text; kept here as comments so the file parses.
import os
import json
import tqdm
import argparse
from openai import OpenAI
# -----------------------------
# CONFIGURATION
# -----------------------------
# Local filesystem path of the merged, fine-tuned Qwen3-32B subclaim model;
# passed as the model name to the OpenAI-compatible server below.
MODEL_NAME = "/home/mshahidul/readctrl_model/full_model/qwen3-32B_subclaims_BF16_merged"
# OpenAI-compatible endpoint on localhost — presumably a vLLM server; confirm port 8015.
API_URL = "http://localhost:8015/v1"
# "EMPTY" is the conventional placeholder key for local servers that skip auth.
API_KEY = "EMPTY"
# Single module-level client reused by every inference call in this script.
client = OpenAI(base_url=API_URL, api_key=API_KEY)
# -----------------------------
# SUBCLAIM EXTRACTION PROMPT
# -----------------------------
def extraction_prompt(medical_text: str) -> str:
    """Build the annotation prompt for one medical text.

    The prompt instructs the model to split *medical_text* into atomic,
    independently verifiable subclaims and to answer with ONLY a JSON
    list of strings.
    """
    # Annotation rules shown to the model, one bullet per constraint.
    rules = (
        "- Use only information explicitly present in the text.\n"
        "- Do not infer or hallucinate.\n"
        "- Subclaims must be atomic and factual.\n"
        "- Return ONLY a JSON list of strings.\n"
    )
    # Skeleton of the expected JSON answer, spelled out literally.
    answer_shape = (
        "[\n"
        '"subclaim 1",\n'
        '"subclaim 2",\n'
        "...\n"
        "]\n"
    )
    return (
        "\n"
        "You are an expert medical annotator. Extract granular, factual subclaims.\n"
        "A subclaim is the smallest standalone factual unit that can be independently verified.\n"
        "Rules:\n"
        + rules
        + "Medical Text:\n"
        + medical_text + "\n"
        + "Return output as:\n"
        + answer_shape
    )
# -----------------------------
# INFERENCE FUNCTION
# -----------------------------
def infer_subclaims(medical_text: str, temperature: float = 0.2) -> list:
    """Ask the model to extract subclaims from *medical_text*.

    Args:
        medical_text: Source text; blank/None input short-circuits to [].
        temperature: Sampling temperature for the chat completion.

    Returns:
        The parsed JSON list of subclaim strings. If the model's reply is
        not valid JSON, the raw reply string is returned instead (callers
        must tolerate both). On any API failure, returns [].
    """
    # Guard: nothing to extract from empty or whitespace-only input.
    if not medical_text or not medical_text.strip():
        return []
    final_prompt = extraction_prompt(medical_text)
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": final_prompt}],
            max_tokens=1000,
            temperature=temperature,
            top_p=0.9,
        )
        # BUG FIX: content may be None for some responses; `or ""` avoids
        # an AttributeError on .strip().
        res = (response.choices[0].message.content or "").strip()
        # Qwen3 emits chain-of-thought before "</think>"; keep only the answer.
        res = res.split("</think>")[-1].strip()
        try:
            return json.loads(res)
        except json.JSONDecodeError:
            # BUG FIX: was a bare `except:`. Only JSON parse failures fall
            # back to the raw string so the model output is not lost.
            return res
    except Exception as e:
        # Best-effort: log the API error and skip this text.
        print(f"API error: {e}")
        return []
# -----------------------------
# MAIN
# -----------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--file1", type=str, required=True,
                        help="Path to synthetic_data_es_raw_592.json")
    parser.add_argument("--file2", type=str, required=True,
                        help="Path to multiclinsum_gs_train_es.json")
    parser.add_argument("--start_index", type=int, default=0,
                        help="Start index for processing")
    parser.add_argument("--end_index", type=int, default=-1,
                        help="End index for processing (exclusive). -1 = until end")
    args = parser.parse_args()

    FILE1 = args.file1
    FILE2 = args.file2
    SAVE_FOLDER = "/home/mshahidul/readctrl/data/extracting_subclaim"
    os.makedirs(SAVE_FOLDER, exist_ok=True)

    # Output filename encodes the range so parallel range-sharded runs
    # write to distinct files.
    OUTPUT_FILE = os.path.join(
        SAVE_FOLDER,
        f"extracted_subclaims_{args.start_index}_{args.end_index}.json"
    )

    # -----------------------------
    # Load files
    # -----------------------------
    print("Loading input files...")
    # BUG FIX: force UTF-8 on every open() — the corpora are Spanish
    # medical text and the default locale encoding is not portable.
    with open(FILE1, "r", encoding="utf-8") as f:
        file1_data = {x["id"]: x for x in json.load(f)}
    with open(FILE2, "r", encoding="utf-8") as f:
        file2_data = {x["id"]: x for x in json.load(f)}

    # -----------------------------
    # Merge and slice by range
    # -----------------------------
    # Sorted union of IDs so the slice indices are stable across runs.
    all_ids = sorted(set(file1_data.keys()) | set(file2_data.keys()))
    total_items = len(all_ids)
    start = args.start_index
    end = args.end_index if args.end_index != -1 else total_items
    slice_ids = all_ids[start:end]
    print(f"Total IDs: {total_items}")
    # BUG FIX: the start/end separator was garbled out of the message.
    print(f"Processing range: {start}-{end} (count={len(slice_ids)})")

    # -----------------------------
    # Resume mode: reload a partial output file and skip finished IDs.
    # -----------------------------
    result = []
    if os.path.exists(OUTPUT_FILE):
        try:
            with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
                result = json.load(f)
        except (OSError, json.JSONDecodeError):
            # BUG FIX: was a bare `except:`. A truncated/corrupt checkpoint
            # is deliberately discarded and rebuilt from scratch.
            result = []
    existing_ids = {r["id"] for r in result}

    # -----------------------------
    # Process items
    # -----------------------------
    for _id in tqdm.tqdm(slice_ids):
        if _id in existing_ids:
            continue  # finished in a previous (resumed) run

        # FILE1 supplies three readability-levelled variants of the text.
        easy_text = inter_text = hard_text = ""
        if _id in file1_data:
            rv = file1_data[_id]["readability_versions"]
            easy_text = rv.get("easy", {}).get("text", "")
            inter_text = rv.get("intermediate", {}).get("text", "")
            hard_text = rv.get("hard", {}).get("text", "")

        # FILE2 supplies the full clinical text and its reference summary.
        fulltext = summary = ""
        if _id in file2_data:
            fulltext = file2_data[_id].get("fulltext", "")
            summary = file2_data[_id].get("summary", "")

        # Extract subclaims from every available variant (empty text -> []).
        easy_sub = infer_subclaims(easy_text)
        inter_sub = infer_subclaims(inter_text)
        hard_sub = infer_subclaims(hard_text)
        fulltext_sub = infer_subclaims(fulltext)
        summary_sub = infer_subclaims(summary)

        result.append({
            "id": _id,
            "easy_text": easy_text,
            "easy_subclaims": easy_sub,
            "intermediate_text": inter_text,
            "intermediate_subclaims": inter_sub,
            "hard_text": hard_text,
            "hard_subclaims": hard_sub,
            "fulltext": fulltext,
            "fulltext_subclaims": fulltext_sub,
            "summary": summary,
            "summary_subclaims": summary_sub
        })

        # Checkpoint every 20 accumulated records so a crash loses at
        # most 20 items' worth of inference.
        if len(result) % 20 == 0:
            with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
                json.dump(result, f, indent=4, ensure_ascii=False)

    # Final save of the complete range.
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=4, ensure_ascii=False)
    print(f"Done! Saved to: {OUTPUT_FILE}")