import argparse
import json
import os

import tqdm
from openai import OpenAI

# -----------------------------
# CONFIGURATION
# -----------------------------
MODEL_NAME = "/home/mshahidul/readctrl_model/full_model/qwen3-32B_subclaims_BF16_merged"
API_URL = "http://localhost:8015/v1"
API_KEY = "EMPTY"

client = OpenAI(base_url=API_URL, api_key=API_KEY)


# -----------------------------
# SUBCLAIM EXTRACTION PROMPT
# -----------------------------
def extraction_prompt(medical_text: str) -> str:
    return f"""
You are an expert medical annotator. Extract granular, factual subclaims.
A subclaim is the smallest standalone factual unit that can be independently verified.

Rules:
- Use only information explicitly present in the text.
- Do not infer or hallucinate.
- Subclaims must be atomic and factual.
- Return ONLY a JSON list of strings.

Medical Text:
{medical_text}

Return output as:
[
  "subclaim 1",
  "subclaim 2",
  ...
]
"""


# -----------------------------
# INFERENCE FUNCTION
# -----------------------------
def infer_subclaims(medical_text: str, temperature: float = 0.2) -> list:
    """Query the model and return the extracted subclaims as a list of strings."""
    if not medical_text or not medical_text.strip():
        return []

    final_prompt = extraction_prompt(medical_text)
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": final_prompt}],
            max_tokens=1000,
            temperature=temperature,
            top_p=0.9,
        )
        res = response.choices[0].message.content.strip()
        # The original separator string was garbled ('res.split("")' raises
        # ValueError). "</think>" is assumed here: Qwen3-style models may
        # prepend a <think>...</think> reasoning block, and if the tag is
        # absent this split is a harmless no-op.
        res = res.split("</think>")[-1].strip()

        # Parse the JSON list; fall back to the outermost bracketed span in
        # case the model wrapped the list in extra prose or a code fence.
        try:
            return json.loads(res)
        except json.JSONDecodeError:
            start, end = res.find("["), res.rfind("]")
            if start != -1 and end > start:
                try:
                    return json.loads(res[start:end + 1])
                except json.JSONDecodeError:
                    pass
            print("Warning: could not parse model output as JSON; returning [].")
            return []
    except Exception as e:
        print(f"API error: {e}")
        return []


# -----------------------------
# MAIN
# -----------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--file1", type=str, required=True,
                        help="Path to synthetic_data_es_raw_592.json")
    parser.add_argument("--file2", type=str, required=True,
                        help="Path to multiclinsum_gs_train_es.json")
    parser.add_argument("--start_index", type=int, default=0,
                        help="Start index for processing")
    parser.add_argument("--end_index", type=int, default=-1,
                        help="End index for processing (exclusive); -1 = until end")
    args = parser.parse_args()

    FILE1 = args.file1
    FILE2 = args.file2

    SAVE_FOLDER = "/home/mshahidul/readctrl/data/extracting_subclaim"
    os.makedirs(SAVE_FOLDER, exist_ok=True)

    # Output filename includes the processed range so parallel runs
    # over different slices don't collide.
    OUTPUT_FILE = os.path.join(
        SAVE_FOLDER,
        f"extracted_subclaims_{args.start_index}_{args.end_index}.json",
    )

    # -----------------------------
    # Load files
    # -----------------------------
    print("Loading input files...")
    with open(FILE1, "r", encoding="utf-8") as f:
        file1_data = {x["id"]: x for x in json.load(f)}
    with open(FILE2, "r", encoding="utf-8") as f:
        file2_data = {x["id"]: x for x in json.load(f)}

    # -----------------------------
    # Merge and slice by range
    # -----------------------------
    all_ids = sorted(set(file1_data) | set(file2_data))
    total_items = len(all_ids)

    start = args.start_index
    end = args.end_index if args.end_index != -1 else total_items
    slice_ids = all_ids[start:end]

    print(f"Total IDs: {total_items}")
    print(f"Processing range: {start} → {end} (count={len(slice_ids)})")

    # -----------------------------
    # Resume mode: reload any partial output and skip IDs already done
    # -----------------------------
    result = []
    if os.path.exists(OUTPUT_FILE):
        try:
            with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
                result = json.load(f)
        except (json.JSONDecodeError, OSError):
            result = []
    existing_ids = {r["id"] for r in result}

    # -----------------------------
    # Process items
    # -----------------------------
    for _id in tqdm.tqdm(slice_ids):
        if _id in existing_ids:
            continue

        # FILE1: three readability versions of the same case.
        easy_text = inter_text = hard_text = ""
        if _id in file1_data:
            rv = file1_data[_id]["readability_versions"]
            easy_text = rv.get("easy", {}).get("text", "")
            inter_text = rv.get("intermediate", {}).get("text", "")
            hard_text = rv.get("hard", {}).get("text", "")

        # FILE2: full clinical text and its gold summary.
        fulltext = summary = ""
        if _id in file2_data:
            fulltext = file2_data[_id].get("fulltext", "")
            summary = file2_data[_id].get("summary", "")

        # Run subclaim extraction on every text variant.
        easy_sub = infer_subclaims(easy_text)
        inter_sub = infer_subclaims(inter_text)
        hard_sub = infer_subclaims(hard_text)
        fulltext_sub = infer_subclaims(fulltext)
        summary_sub = infer_subclaims(summary)

        result.append({
            "id": _id,
            "easy_text": easy_text,
            "easy_subclaims": easy_sub,
            "intermediate_text": inter_text,
            "intermediate_subclaims": inter_sub,
            "hard_text": hard_text,
            "hard_subclaims": hard_sub,
            "fulltext": fulltext,
            "fulltext_subclaims": fulltext_sub,
            "summary": summary,
            "summary_subclaims": summary_sub,
        })

        # Checkpoint every 20 records so progress survives interruptions.
        if len(result) % 20 == 0:
            with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
                json.dump(result, f, indent=4, ensure_ascii=False)

    # Final save.
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=4, ensure_ascii=False)

    print(f"Done! Saved to: {OUTPUT_FILE}")