"""Classify the readability of generated medical summaries with a fine-tuned LLM.

Loads a Qwen3-based classifier via Unsloth, prompts it to score each summary
on a 1-5 readability scale, and writes results to JSON with resume support.
"""

import os

# Pin GPU selection BEFORE torch/unsloth are imported, otherwise it has no effect.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import torch
from unsloth import FastLanguageModel
import json
import tqdm
import re

# -----------------------------
# MODEL CACHE
# -----------------------------
# Process-level singleton so repeated infer_readability() calls reuse one model.
_model_cache = {"model": None, "tokenizer": None}


def load_finetuned_model(model_path: str):
    """Load (or return the cached) fine-tuned model and tokenizer.

    Args:
        model_path: Path or hub name of the fine-tuned classifier checkpoint.
            NOTE: only honored on the first call; later calls return the
            cached model regardless of the path passed.

    Returns:
        Tuple of (model, tokenizer) ready for inference.
    """
    if _model_cache["model"] is not None:
        return _model_cache["model"], _model_cache["tokenizer"]

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        max_seq_length=8192,
        load_in_4bit=False,  # Set to True if you want 4bit inference for speed/memory
        load_in_8bit=False,
        full_finetuning=False,
    )
    # Enable native 2x faster inference
    FastLanguageModel.for_inference(model)

    _model_cache["model"], _model_cache["tokenizer"] = model, tokenizer
    return model, tokenizer


# -----------------------------
# READABILITY CLASSIFICATION PROMPT
# -----------------------------
def classification_prompt(full_text: str, summary: str) -> str:
    """Build the evaluator prompt for scoring a summary's readability.

    Args:
        full_text: The source document (context only, not itself scored).
        summary: The generated text to be scored on the 1-5 scale.

    Returns:
        The full prompt string instructing the model to emit a JSON object
        of the form {"readability_score": <int>}.
    """
    prompt = f"""You are a medical readability evaluator. ### Task Compare the "GENERATED TEXT" against the "FULL TEXT" to determine its readability for a general, non-medical audience. ### Input Data - **FULL TEXT:** {full_text} - **GENERATED TEXT (Evaluate this):** {summary} ### Readability Scale 1: Very Easy - Minimal medical language, uses simple terms. 2: Easy - Accessible to most, minor jargon explained. 3: Medium - Some technical terms, moderate complexity. 4: Hard - Clinical tone, assumes some prior knowledge. 5: Very Hard - Extremely technical, requires medical expertise. ### Constraints - Evaluate ONLY the "GENERATED TEXT". - Use "FULL TEXT" only for context of the subject matter. - Do NOT assess factual accuracy. 
### Output Format Return ONLY a valid JSON object: {{ "readability_score": }}"""
    return prompt


# -----------------------------
# INFERENCE FUNCTION
# -----------------------------
def infer_readability(full_text: str, summary: str, model_path: str) -> dict:
    """Run the classifier on one (full_text, summary) pair.

    Args:
        full_text: Source document used as context.
        summary: Generated text whose readability is scored.
        model_path: Checkpoint path passed to load_finetuned_model().

    Returns:
        Parsed JSON dict (expected key: "readability_score"); on parse
        failure, {"readability_score": "error", "raw": <model output>}.
    """
    model, tokenizer = load_finetuned_model(model_path)
    prompt = classification_prompt(full_text, summary)

    messages = [{"role": "user", "content": prompt}]
    chat_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(chat_text, return_tensors="pt").to("cuda")

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=50,  # Classification only needs a few tokens
            # Greedy decoding for classification consistency; temperature is
            # ignored when do_sample=False, so it is deliberately not set.
            do_sample=False,
        )

    # Decode only the newly generated tokens (skip the echoed prompt).
    output_text = tokenizer.decode(
        output_ids[0][len(inputs.input_ids[0]):],
        skip_special_tokens=True,
    ).strip()

    # BUGFIX: the original tested `"" in output_text` (always True) and then
    # called .split("") which raises ValueError: empty separator. The intent
    # was to strip a reasoning block terminated by a </think> tag.
    if "</think>" in output_text:
        output_text = output_text.split("</think>")[-1].strip()

    # Simple regex to extract JSON if the model adds conversational filler
    match = re.search(r"\{.*\}", output_text, re.DOTALL)
    if match:
        try:
            return json.loads(match.group())
        except json.JSONDecodeError:
            pass
    return {"readability_score": "error", "raw": output_text}


# -----------------------------
# MAIN EXECUTION
# -----------------------------
if __name__ == "__main__":
    # Settings based on your paths
    INPUT_FILE = "/home/mshahidul/readctrl/data/processed_raw_data/multiclinsum_test_en.json"
    SAVE_FOLDER = "/home/mshahidul/readctrl/data/classified_readability"

    # Note: Ensure this path points to your CLASSIFIER model, not the subclaim extractor
    MODEL_PATH = "/home/mshahidul/readctrl_model/qwen3-32B_classifier_en"

    os.makedirs(SAVE_FOLDER, exist_ok=True)

    file_name = os.path.basename(INPUT_FILE).split(".json")[0]
    OUTPUT_FILE = os.path.join(SAVE_FOLDER, f"classified_{file_name}.json")

    # Load input dataset
    with open(INPUT_FILE, "r") as f:
        data = json.load(f)

    # Resume mode: reload prior results if present.
    result = []
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, "r") as f:
            result = json.load(f)
    # BUGFIX: compute existing_ids unconditionally so a fresh run (no prior
    # output file) does not hit a NameError on the skip check below.
    existing_ids = {item["id"] for item in result}

    print(f"Starting classification. Saving to: {OUTPUT_FILE}")

    for item in tqdm.tqdm(data):
        if item["id"] in existing_ids:
            continue

        full_text = item.get("fulltext", "")
        summary = item.get("summary", "")

        classification_res = infer_readability(
            full_text=full_text,
            summary=summary,
            model_path=MODEL_PATH,
        )

        result.append({
            "id": item["id"],
            "readability_score": classification_res.get("readability_score"),
            "fulltext": full_text,
            "summary": summary,
        })

        # Checkpoint every 50 items
        if len(result) % 50 == 0:
            with open(OUTPUT_FILE, "w") as f:
                json.dump(result, f, indent=4, ensure_ascii=False)

    # Final save
    with open(OUTPUT_FILE, "w") as f:
        json.dump(result, f, indent=4, ensure_ascii=False)

    print(f"Classification completed. {len(result)} items processed.")