| | import os |
| | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" |
| | os.environ["CUDA_VISIBLE_DEVICES"] = "4" |
| | import torch |
| | from unsloth import FastLanguageModel |
| | import json |
| |
|
| | |
| | |
| | |
| |
|
| |
|
| | |
| | |
| | |
| | _model_cache = {"model": None, "tokenizer": None} |
| |
|
def load_finetuned_model(model_path: str):
    """Load and cache the fine-tuned model + tokenizer.

    Args:
        model_path: Path (or hub id) of the fine-tuned checkpoint to load.

    Returns:
        ``(model, tokenizer)`` as produced by Unsloth's
        ``FastLanguageModel.from_pretrained``.

    The cache is keyed by ``model_path``: the original implementation returned
    whatever was cached regardless of the requested path, so a second call
    with a *different* checkpoint silently got the wrong model.
    """
    if _model_cache["model"] is not None and _model_cache.get("path") == model_path:
        return _model_cache["model"], _model_cache["tokenizer"]

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        max_seq_length=4096,
        load_in_4bit=False,   # full-precision weights; no quantization
        load_in_8bit=False,
        full_finetuning=False,
    )
    _model_cache["model"] = model
    _model_cache["tokenizer"] = tokenizer
    _model_cache["path"] = model_path  # remember which checkpoint is cached
    return model, tokenizer
| |
|
| |
|
| | |
| | |
| | |
def infer_reasonableness(
    reference_summary: str,
    generated_summary: str,
    readability_level: str,
    subclaim_text: str,
    result: int,
    model_path: str = "/home/mshahidul/readctrl_model/qwen3-32B_subclaims-completeness_resonability_check_8kCtx_v3",
):
    """
    Given the reference summary, generated summary, readability level, subclaim, and its result (0/1),
    predict reasonableness: reasonable / partially_reasonable / unreasonable, plus justification.

    Args:
        reference_summary: Gold reference summary text.
        generated_summary: The readability-controlled generated summary.
        readability_level: Readability label interpolated into the prompt as-is.
        subclaim_text: The single subclaim being judged.
        result: 1 = subclaim supported (included), 0 = omitted.
        model_path: Checkpoint forwarded to load_finetuned_model().

    Returns:
        A dict with keys "reasonableness" and "justification" when the model
        emits valid JSON; otherwise the raw cleaned output string so the
        caller can inspect or repair malformed output.
    """
    model, tokenizer = load_finetuned_model(model_path)

    prompt = f"""
You are an impartial medical summarization evaluator.

Goal:
Decide whether the inclusion or omission of ONE specific subclaim from the reference summary is *reasonable*, given the readability level of the generated summary.

Readability Criteria:
- Easy: for non-medical readers; emphasize main story and outcomes; omit numerical data, anatomy, and test details.
- Intermediate: for general educated readers; keep main findings but simplify phrasing.
- Hard: for clinical or technical readers; maintain diagnostic accuracy and essential quantitative or anatomic content.

Judging rules:
* Base your decision strictly on what appears in the generated summary.
* If result = 0 (subclaim omitted) and the omitted detail is clearly technical or numerical for the given level, choose "reasonable".
* If result = 0 and the subclaim is essential to the main story, choose "unreasonable".
* Stay consistent between `result`, justification, and readability level.

### Inputs
Readability Level: {readability_level}
Reference Summary: {reference_summary}
Generated Summary: {generated_summary}
Subclaim: "{subclaim_text}"
Result: {result}   # 1 = supported (included), 0 = omitted

### Task
Respond **only** with the following JSON object:

{{
  "reasonableness": "<reasonable | partially_reasonable | unreasonable>",
  "justification": "<short clear explanation>"
}}
""".strip()

    messages = [{"role": "user", "content": prompt + "\n"}]

    chat_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,  # Qwen3 chat template: disable the reasoning block
    )

    inputs = tokenizer(chat_text, return_tensors="pt").to("cuda")

    with torch.no_grad():
        # Greedy decoding only. The original also passed temperature/top_p/top_k,
        # but those are ignored (with a warning) when do_sample=False.
        output_ids = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=False,
        )

    # Decode only the newly generated tokens. The original decoded the whole
    # sequence (prompt + completion) and relied on split("</think>")[1] to
    # strip the prompt, which raised IndexError whenever no "</think>" tag
    # appeared in the decoded text.
    prompt_len = inputs["input_ids"].shape[1]
    output_text = tokenizer.decode(
        output_ids[0][prompt_len:], skip_special_tokens=True
    ).strip()

    # Drop an optional think block ([-1] is safe when no tag is present) and
    # any markdown code fences around the JSON payload.
    output_text = (
        output_text.split("</think>")[-1]
        .strip()
        .replace("```json", "")
        .replace("```", "")
    )

    try:
        parsed = json.loads(output_text)
    except Exception:
        # Best-effort: hand the raw text back rather than crashing the batch.
        parsed = output_text
    return parsed
| |
|
| |
|
| | |
| | |
| | |
| | if __name__ == "__main__": |
| | |
| | |
| | |
| | |
| | |
| | import json |
| | with open('/home/mshahidul/readctrl/data/testing_data_gs/multiclinsum_gs_train_es.json', 'r') as f: |
| | multiclinsum_gs_train_es_data = json.load(f) |
| | ref_summaries={} |
| | fulltexts={} |
| | for item in multiclinsum_gs_train_es_data: |
| | ref_summaries[item['id']]=item['summary'] |
| | fulltexts[item['id']]=item['fulltext'] |
| | |
| | generated_summaries = {} |
| | with open('/home/mshahidul/readctrl/data/hand_create_gpt5_other_model/synthetic_data_es_raw_592.json', 'r') as f: |
| | synthetic_data_es_raw_592 = json.load(f) |
| | for item in synthetic_data_es_raw_592: |
| | for version in ['easy', 'intermediate', 'hard']: |
| | generated_summaries[(item['id'], version)] = item['readability_versions'][version]['text'] |
| | |
| | with open("/home/mshahidul/readctrl/results/dataset_quality_check/subclaim_verifier_results_100_qwen3-32B.json", 'r') as f: |
| | qwen3_32B_results = json.load(f) |
| | full_res = [] |
| | save_path = "/home/mshahidul/readctrl/results/dataset_quality_check/completeness_resonability_check_100_qwen3-32B_v4.json" |
| | import tqdm |
| | for idx, item in tqdm.tqdm(enumerate(qwen3_32B_results)): |
| | print(f"Processing item {idx + 1}/{len(qwen3_32B_results)}") |
| | reference_summary = ref_summaries[item['id']] |
| | fulltext = fulltexts[item['id']] |
| | generated_summary = generated_summaries[(item['id'], item['version'])] |
| | temp_res = [] |
| | for item2 in item['completeness']['results']: |
| | subclaim_text = item2['subclaim']['subclaim'] |
| | result = item2['result'] |
| | if result =="1": |
| | continue |
| | response = infer_reasonableness( |
| | reference_summary, |
| | generated_summary, |
| | item['version'], |
| | subclaim_text, |
| | result, |
| | model_path="/home/mshahidul/readctrl_model/qwen3-32B_subclaims-completeness_resonability_check_8kCtx_v3", |
| | ) |
| | temp_res.append({ |
| | 'id':item2['subclaim']['id'], |
| | "subclaim": subclaim_text, |
| | "result": result, |
| | "reasonableness": response |
| | }) |
| | full_res.append({ |
| | "id": item['id'], |
| | "version": item['version'], |
| | "completeness": { |
| | "results": temp_res |
| | } |
| | }) |
| | if len(full_res)%10==0: |
| | with open(save_path, 'w') as f: |
| | json.dump(full_res, f, indent=2, ensure_ascii=False) |
| |
|
| | with open(save_path, 'w') as f: |
| | json.dump(full_res, f, indent=2, ensure_ascii=False) |
| |
|
| |
|