import os

# GPU pinning must be set BEFORE torch / unsloth are imported so CUDA
# initializes against the intended device only.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import argparse
import json

import torch
import tqdm
from unsloth import FastLanguageModel

# -----------------------------
# MODEL CACHE
# -----------------------------
# Process-wide singleton: repeated load_finetuned_model() calls reuse the
# already-loaded weights instead of reloading from disk.
_model_cache = {"model": None, "tokenizer": None}


def load_finetuned_model(model_path: str):
    """Return (model, tokenizer) for the fine-tuned checkpoint at `model_path`.

    The first call loads the model; subsequent calls return the cached pair.
    NOTE(review): on a cache hit `model_path` is ignored — calling twice with
    different paths silently returns the first model. Confirm this is intended.
    """
    if _model_cache["model"] is not None:
        return _model_cache["model"], _model_cache["tokenizer"]

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        max_seq_length=8192,
        load_in_4bit=False,
        load_in_8bit=False,
        full_finetuning=False,
    )
    _model_cache["model"], _model_cache["tokenizer"] = model, tokenizer
    return model, tokenizer


# -----------------------------
# SUBCLAIM EXTRACTION PROMPT
# -----------------------------
def extraction_prompt(medical_text: str) -> str:
    """Build the instruction prompt asking the model to break `medical_text`
    into atomic, independently-verifiable subclaims, returned as a JSON list."""
    prompt = f"""
You are an expert medical annotator. Your task is to extract granular, factual subclaims from medical text.
A subclaim is the smallest standalone factual unit that can be independently verified.

Instructions:
1. Read the provided medical text.
2. Break it into clear, objective, atomic subclaims.
3. Each subclaim must come directly from the text.
4. Return ONLY a valid JSON list of strings.

Medical Text:
{medical_text}

Return your output in JSON list format:
[
  "subclaim 1",
  "subclaim 2"
]
"""
    return prompt


# -----------------------------
# INFERENCE FUNCTION WITH AUTO-RETRY
# -----------------------------
def infer_subclaims(medical_text: str, model, tokenizer,
                    temperature: float = 0.2, max_tokens: int = 2048,
                    retries: int = 1) -> list:
    """Extract subclaims from `medical_text` with the loaded model.

    Returns a list of subclaim strings. If the model output cannot be parsed
    as a JSON list (typically because generation was truncated), retries up to
    `retries` times with a `max_new_tokens` budget grown by 2048 each attempt.
    Final fallback is the raw model output wrapped in a one-element list so
    the calling pipeline never crashes on a bad generation.
    """
    if not medical_text or medical_text.strip() == "":
        return []

    prompt = extraction_prompt(medical_text)
    messages = [{"role": "user", "content": prompt}]
    chat_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(chat_text, return_tensors="pt").to("cuda")

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,  # no effect under greedy decoding (do_sample=False)
            do_sample=False,
        )

    # BUG FIX: decode only the generated continuation. generate() returns the
    # prompt ids followed by the new tokens; decoding the full sequence would
    # include the prompt's own JSON example, whose brackets corrupt the
    # find('[') / rfind(']') extraction below.
    prompt_len = inputs["input_ids"].shape[1]
    output_text = tokenizer.decode(
        output_ids[0][prompt_len:], skip_special_tokens=True
    ).strip()

    # Remove the reasoning trace if the model is a "Thinker" model.
    # BUG FIX: the original tested/split on an empty string ("" is always
    # "in" any text, and str.split("") raises ValueError: empty separator);
    # the intended separator is the closing </think> tag.
    if "</think>" in output_text:
        output_text = output_text.split("</think>")[-1].strip()

    # JSON Parsing Logic: take the outermost [...] span and parse it.
    try:
        start_idx = output_text.find('[')
        end_idx = output_text.rfind(']') + 1
        # Check if we have a complete bracketed pair
        if start_idx != -1 and end_idx > start_idx:
            parsed = json.loads(output_text[start_idx:end_idx])
            if isinstance(parsed, list):
                return parsed
        # Parsing failed or brackets were incomplete — assume truncation.
        raise ValueError("Incomplete JSON list")
    except (json.JSONDecodeError, ValueError):
        # If truncation happened and we have retries left, grow the budget.
        if retries > 0:
            new_max = max_tokens + 2048  # increment by 2k tokens
            print(f"\n[Warning] Truncation detected. Retrying with {new_max} tokens...")
            return infer_subclaims(medical_text, model, tokenizer,
                                   temperature, max_tokens=new_max,
                                   retries=retries - 1)
        # Final fallback: return the raw text wrapped in a list so the
        # pipeline doesn't crash.
        return [output_text]


def _item_key(item: dict) -> str:
    """Stable resume key for a record: 'index' when present (including 0),
    otherwise 'id'.

    BUG FIX: the original built the resume map with `index or id` but looked
    items up with `index if index is not None else id`, so an item with
    index == 0 was keyed by its id on load yet looked up as "0" in the loop
    and re-processed on every resume. One helper, used in both places.
    """
    return str(item.get("index") if item.get("index") is not None else item.get("id"))


# -----------------------------
# MAIN EXECUTION
# -----------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True)
    args = parser.parse_args()

    INPUT_FILE = args.input_file
    file_name = os.path.basename(INPUT_FILE).split(".json")[0]
    SAVE_FOLDER = "/home/mshahidul/readctrl/data/extracting_subclaim"
    MODEL_PATH = "/home/mshahidul/readctrl_model/qwen3-32B_subclaims-extraction-8b_ctx"
    os.makedirs(SAVE_FOLDER, exist_ok=True)
    OUTPUT_FILE = os.path.join(SAVE_FOLDER, f"extracted_subclaims_{file_name}.json")

    model, tokenizer = load_finetuned_model(MODEL_PATH)

    with open(INPUT_FILE, "r") as f:
        data = json.load(f)

    # Resume support: reload any previous partial output into a keyed map.
    result = []
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, "r") as f:
            result = json.load(f)
    processed_data = {_item_key(item): item for item in result}

    for item in tqdm.tqdm(data):
        item_id = _item_key(item)
        existing_entry = processed_data.get(item_id)

        # 1. Process Fulltext (the longest field → highest initial token budget)
        if not existing_entry or not isinstance(existing_entry.get("fulltext_subclaims"), list):
            f_sub = infer_subclaims(item.get("fulltext", ""), model, tokenizer,
                                    max_tokens=3072, retries=2)
        else:
            f_sub = existing_entry["fulltext_subclaims"]

        # 2. Process Summary
        if not existing_entry or not isinstance(existing_entry.get("summary_subclaims"), list):
            s_sub = infer_subclaims(item.get("summary", ""), model, tokenizer,
                                    max_tokens=2048, retries=1)
        else:
            s_sub = existing_entry["summary_subclaims"]

        # 3. Process All Generated Texts (diff_label_texts)
        diff_label_texts = item.get("diff_label_texts", {})
        diff_label_subclaims = existing_entry.get("diff_label_subclaims", {}) if existing_entry else {}
        for label, text in diff_label_texts.items():
            if label not in diff_label_subclaims or not isinstance(diff_label_subclaims[label], list):
                # Generated texts are shorter, but we still allow 1 retry
                diff_label_subclaims[label] = infer_subclaims(text, model, tokenizer,
                                                              max_tokens=1536, retries=1)

        # 4. Save the (possibly partially reused) entry back into the map
        new_entry = {
            "index": item.get("index"),
            "id": item.get("id"),
            "fulltext": item.get("fulltext", ""),
            "fulltext_subclaims": f_sub,
            "summary": item.get("summary", ""),
            "summary_subclaims": s_sub,
            "diff_label_texts": diff_label_texts,
            "diff_label_subclaims": diff_label_subclaims,
            "readability_score": item.get("readability_score", None),
        }
        processed_data[item_id] = new_entry

        # Checkpoint to disk every 10 accumulated entries.
        if len(processed_data) % 10 == 0:
            with open(OUTPUT_FILE, "w") as f:
                json.dump(list(processed_data.values()), f, indent=4, ensure_ascii=False)

    with open(OUTPUT_FILE, "w") as f:
        json.dump(list(processed_data.values()), f, indent=4, ensure_ascii=False)

    print(f"Extraction completed. File saved at: {OUTPUT_FILE}")