import os

# ===========================
# GPU SETTINGS
# ===========================
# Must be set BEFORE torch/unsloth are imported so the CUDA runtime
# only exposes the intended physical device.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

import json

import torch
from unsloth import FastLanguageModel

# ===========================
# MODEL LOADING (CACHED)
# ===========================
# Process-wide singleton cache so repeated calls reuse one loaded model.
# NOTE(review): the cache ignores `model_path` — a second call with a
# different path silently returns the first model. Acceptable for this
# single-model script; key the cache by path if that ever changes.
_model_cache = {"model": None, "tokenizer": None}


def load_finetuned_model(model_path: str):
    """Load and cache the fine-tuned model + tokenizer.

    Parameters
    ----------
    model_path : str
        Local path (or hub name) of the fine-tuned checkpoint.

    Returns
    -------
    tuple
        ``(model, tokenizer)`` as returned by Unsloth's
        ``FastLanguageModel.from_pretrained``.
    """
    if _model_cache["model"] is not None:
        return _model_cache["model"], _model_cache["tokenizer"]
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        max_seq_length=4096,
        load_in_4bit=False,
        load_in_8bit=False,
        full_finetuning=False,
    )
    _model_cache["model"], _model_cache["tokenizer"] = model, tokenizer
    return model, tokenizer


# ===========================
# INFERENCE FUNCTION
# ===========================
def infer_reasonableness(
    reference_summary: str,
    generated_summary: str,
    readability_level: str,
    subclaim_text: str,
    result: int,
    model_path: str = "/home/mshahidul/readctrl_model/qwen3-32B_subclaims-completeness_resonability_check_8kCtx_v3",
):
    """
    Given the reference summary, generated summary, readability level,
    subclaim, and its result (0/1), predict reasonableness:
    reasonable / partially_reasonable / unreasonable, plus justification.

    Returns the parsed JSON dict on success, or the raw model output
    string when the model's reply is not valid JSON.
    """
    model, tokenizer = load_finetuned_model(model_path)

    # ---- Build inference prompt (same structure as training) ----
    prompt = f"""
You are an impartial medical summarization evaluator.

Goal:
Decide whether the inclusion or omission of ONE specific subclaim from the reference summary is *reasonable*, given the readability level of the generated summary.

Readability Criteria:
- Easy: for non-medical readers; emphasize main story and outcomes; omit numerical data, anatomy, and test details.
- Intermediate: for general educated readers; keep main findings but simplify phrasing.
- Hard: for clinical or technical readers; maintain diagnostic accuracy and essential quantitative or anatomic content.

Judging rules:
* Base your decision strictly on what appears in the generated summary.
* If result = 0 (subclaim omitted) and the omitted detail is clearly technical or numerical for the given level, choose "reasonable".
* If result = 0 and the subclaim is essential to the main story, choose "unreasonable".
* Stay consistent between `result`, justification, and readability level.

### Inputs
Readability Level: {readability_level}

Reference Summary:
{reference_summary}

Generated Summary:
{generated_summary}

Subclaim: "{subclaim_text}"
Result: {result}  # 1 = supported (included), 0 = omitted

### Task
Respond **only** with the following JSON object:
{{
  "reasonableness": "",
  "justification": ""
}}
""".strip()

    messages = [{"role": "user", "content": prompt + "\n"}]
    chat_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,  # important for Unsloth chat template
    )
    inputs = tokenizer(chat_text, return_tensors="pt").to("cuda")

    # ---- Generate output (greedy) ----
    # Sampling knobs (temperature/top_p/top_k) were previously passed
    # alongside do_sample=False; they are ignored under greedy decoding,
    # so they are dropped here — generation output is unchanged.
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=False,
        )

    # BUG FIX: the original did `output_text.split("")[1]`, and str.split
    # with an empty separator raises ValueError on every call. Decode only
    # the newly generated tokens instead, which cleanly drops the echoed
    # prompt from the output.
    prompt_len = inputs["input_ids"].shape[1]
    output_text = tokenizer.decode(
        output_ids[0][prompt_len:], skip_special_tokens=True
    ).strip()
    output_text = output_text.replace("```json", "").replace("```", "").strip()

    # ---- Extract model JSON output ----
    try:
        parsed = json.loads(output_text)
    except json.JSONDecodeError:
        # Model sometimes replies with non-JSON text; return it raw so the
        # caller can inspect/retry rather than crashing the batch run.
        parsed = output_text
    return parsed


# ===========================
# EXAMPLE USAGE
# ===========================
if __name__ == "__main__":
    import tqdm

    # ---- Reference summaries and full texts, keyed by document id ----
    with open('/home/mshahidul/readctrl/data/testing_data_gs/multiclinsum_gs_train_es.json', 'r') as f:
        multiclinsum_gs_train_es_data = json.load(f)
    ref_summaries = {}
    fulltexts = {}
    for item in multiclinsum_gs_train_es_data:
        ref_summaries[item['id']] = item['summary']
        fulltexts[item['id']] = item['fulltext']

    # ---- Generated summaries, keyed by (id, readability version) ----
    generated_summaries = {}
    with open('/home/mshahidul/readctrl/data/hand_create_gpt5_other_model/synthetic_data_es_raw_592.json', 'r') as f:
        synthetic_data_es_raw_592 = json.load(f)
    for item in synthetic_data_es_raw_592:
        for version in ['easy', 'intermediate', 'hard']:
            generated_summaries[(item['id'], version)] = item['readability_versions'][version]['text']

    # ---- Subclaim verifier results to re-judge for reasonableness ----
    with open("/home/mshahidul/readctrl/results/dataset_quality_check/subclaim_verifier_results_100_qwen3-32B.json", 'r') as f:
        qwen3_32B_results = json.load(f)

    full_res = []
    save_path = "/home/mshahidul/readctrl/results/dataset_quality_check/completeness_resonability_check_100_qwen3-32B_v4.json"

    for idx, item in tqdm.tqdm(enumerate(qwen3_32B_results), total=len(qwen3_32B_results)):
        print(f"Processing item {idx + 1}/{len(qwen3_32B_results)}")
        reference_summary = ref_summaries[item['id']]
        fulltext = fulltexts[item['id']]
        generated_summary = generated_summaries[(item['id'], item['version'])]

        temp_res = []
        for item2 in item['completeness']['results']:
            subclaim_text = item2['subclaim']['subclaim']
            result = item2['result']
            # Skip subclaims marked as supported — only omissions are judged.
            # NOTE(review): the original compared `result == "1"` (string);
            # str() makes this robust whether the verifier file stores the
            # result as "1" or 1 — confirm against the results schema.
            if str(result) == "1":
                continue
            response = infer_reasonableness(
                reference_summary,
                generated_summary,
                item['version'],
                subclaim_text,
                result,
                model_path="/home/mshahidul/readctrl_model/qwen3-32B_subclaims-completeness_resonability_check_8kCtx_v3",
            )
            temp_res.append({
                'id': item2['subclaim']['id'],
                "subclaim": subclaim_text,
                "result": result,
                "reasonableness": response,
            })

        full_res.append({
            "id": item['id'],
            "version": item['version'],
            "completeness": {"results": temp_res},
        })

        # Checkpoint every 10 processed items so a crash loses little work.
        if len(full_res) % 10 == 0:
            with open(save_path, 'w') as f:
                json.dump(full_res, f, indent=2, ensure_ascii=False)

    # Final save of all results.
    with open(save_path, 'w') as f:
        json.dump(full_res, f, indent=2, ensure_ascii=False)