import os
import json
import argparse

import tqdm
from openai import OpenAI

# -----------------------------
# API CONFIGURATION
# -----------------------------
LOCAL_API_URL = "http://172.16.34.29:8004/v1"
LOCAL_MODEL_NAME = "/home/mshahidul/readctrl_model/full_model/qwen3-32B_subclaims-extraction-8b_ctx_fp16"

# vLLM's OpenAI-compatible endpoint ignores the key, but the client requires one.
client = OpenAI(
    base_url=LOCAL_API_URL,
    api_key="EMPTY"
)

# Reasoning models (Qwen3 family) may emit a <think>...</think> prefix before
# the actual answer; we strip everything up to this closing tag.
THINK_END_TAG = "</think>"


# -----------------------------
# SUBCLAIM EXTRACTION PROMPT
# -----------------------------
def extraction_prompt(medical_text: str) -> str:
    """Build the annotation prompt asking the model to return a JSON array of subclaims."""
    return f"""
You are an expert medical annotator. Your task is to extract granular, factual subclaims from the provided medical text.
A subclaim is the smallest standalone factual unit that can be independently verified.

Instructions:
1. Read the medical text carefully.
2. Extract factual statements explicitly stated in the text.
3. Each subclaim must:
   - Contain exactly ONE factual assertion
   - Come directly from the text (no inference or interpretation)
   - Preserve original wording as much as possible
   - Include any negation, uncertainty, or qualifier (e.g., "may", "not", "suggests")
4. Do NOT:
   - Combine multiple facts into one subclaim
   - Add new information
   - Rephrase or normalize terminology
   - Include opinions or recommendations
5. Return ONLY a valid JSON array of strings.
6. Use double quotes and valid JSON formatting only (no markdown, no commentary).

Medical Text:
{medical_text}

Return format:
[
  "subclaim 1",
  "subclaim 2"
]
""".strip()


# -----------------------------
# INFERENCE FUNCTION (vLLM API)
# -----------------------------
def infer_subclaims_api(medical_text: str, temperature: float = 0.2,
                        max_tokens: int = 2048, retries: int = 1) -> list:
    """Call the local vLLM model and parse its output into a list of subclaims.

    Args:
        medical_text: Source text to decompose; empty/whitespace input short-circuits to [].
        temperature: Sampling temperature passed to the API.
        max_tokens: Completion budget; grown by 2048 on each retry to recover
            from truncated (and therefore unparseable) JSON.
        retries: Remaining retry attempts on API error or parse failure.

    Returns:
        The parsed JSON list of subclaim strings. If all retries are exhausted,
        falls back to a single-element list holding the raw model output (when
        any output was received), otherwise an empty list.
    """
    if not medical_text or not medical_text.strip():
        return []

    prompt = extraction_prompt(medical_text)
    try:
        response = client.chat.completions.create(
            model=LOCAL_MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens,
        )
        output_text = response.choices[0].message.content.strip()

        # BUG FIX: the original checked `if "" in output_text` and then called
        # `output_text.split("")`, which always raises ValueError("empty
        # separator") and forced a spurious retry on every call. The apparent
        # intent (Qwen3 is a reasoning model) was to drop the <think>...</think>
        # reasoning prefix — TODO confirm the exact tag against model output.
        if THINK_END_TAG in output_text:
            output_text = output_text.split(THINK_END_TAG)[-1].strip()

        # Extract the outermost JSON array even if surrounded by stray text.
        start_idx = output_text.find('[')
        end_idx = output_text.rfind(']') + 1
        if start_idx != -1 and end_idx > start_idx:
            parsed = json.loads(output_text[start_idx:end_idx])
            if isinstance(parsed, list):
                return parsed
        raise ValueError("Incomplete JSON list")
    except Exception as e:  # API errors, JSONDecodeError, ValueError above
        if retries > 0:
            new_max = max_tokens + 2048
            print(f"\n[Warning] API error/truncation: {e}. Retrying with {new_max} tokens...")
            return infer_subclaims_api(medical_text, temperature,
                                       max_tokens=new_max, retries=retries - 1)
        # Last resort: surface whatever raw text we got so data is not lost.
        return [output_text] if 'output_text' in locals() else []


# -----------------------------
# MAIN EXECUTION
# -----------------------------
def main() -> None:
    """Extract subclaims for a slice of the input dataset, with resume support."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--start", type=int, default=0,
                        help="Start index in the dataset")
    parser.add_argument("--end", type=int, default=None,
                        help="End index (exclusive) in the dataset")
    args = parser.parse_args()

    input_file = args.input_file
    file_name = os.path.basename(input_file).split(".json")[0]

    save_folder = "/home/mshahidul/readctrl/data/extracting_subclaim"
    os.makedirs(save_folder, exist_ok=True)

    # Range-specific output naming helps if you want to run parallel jobs.
    range_suffix = f"_{args.start}_{args.end if args.end is not None else 'end'}"
    output_file = os.path.join(
        save_folder, f"extracted_subclaims_{file_name}{range_suffix}.json")

    with open(input_file, "r") as f:
        full_data = json.load(f)

    if args.end is None:
        args.end = len(full_data)

    # Slice the data based on user input.
    data_subset = full_data[args.start:args.end]
    # NOTE: args.end is already resolved above; the original printed
    # len(full_data) when --end was the (falsy) value 0.
    print(f"Processing range [{args.start} : {args.end}]. "
          f"Total: {len(data_subset)} items.")

    # Load existing progress if available so interrupted runs can resume.
    processed_data = {}
    if os.path.exists(output_file):
        with open(output_file, "r") as f:
            existing_list = json.load(f)
        processed_data = {str(item.get("id")): item for item in existing_list}

    for item in tqdm.tqdm(data_subset):
        item_id = str(item.get("id"))

        # Skip items already present in the checkpoint file.
        if item_id in processed_data:
            continue

        # 1. Process fulltext (longer input -> bigger budget, more retries).
        f_sub = infer_subclaims_api(item.get("fulltext", ""),
                                    max_tokens=3072, retries=2)
        # 2. Process summary.
        s_sub = infer_subclaims_api(item.get("summary", ""),
                                    max_tokens=2048, retries=1)

        # 3. Save entry.
        processed_data[item_id] = {
            "id": item_id,
            "fulltext": item.get("fulltext", ""),
            "fulltext_subclaims": f_sub,
            "summary": item.get("summary", ""),
            "summary_subclaims": s_sub
        }

        # Periodic checkpoint every 20 completed items.
        if len(processed_data) % 20 == 0:
            with open(output_file, "w") as f:
                json.dump(list(processed_data.values()), f,
                          indent=4, ensure_ascii=False)

    # Final save.
    with open(output_file, "w") as f:
        json.dump(list(processed_data.values()), f,
                  indent=4, ensure_ascii=False)

    print(f"Range extraction completed. File saved at: {output_file}")


if __name__ == "__main__":
    main()