import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
import torch
from unsloth import FastLanguageModel
import json
# ===========================
# GPU SETTINGS
# ===========================
# ===========================
# MODEL LOADING (CACHED)
# ===========================
# Module-level cache so repeated calls reuse the already-loaded model.
# "path" records which checkpoint the cached model came from, so a request
# for a different model_path triggers a reload instead of silently
# returning a previously loaded, different model (the original cache
# ignored model_path entirely).
_model_cache = {"model": None, "tokenizer": None, "path": None}


def load_finetuned_model(model_path: str):
    """Load the fine-tuned model + tokenizer, cached per checkpoint path.

    Args:
        model_path: Filesystem path (or hub name) of the fine-tuned checkpoint.

    Returns:
        (model, tokenizer) as produced by Unsloth's FastLanguageModel.
    """
    if _model_cache["model"] is not None and _model_cache["path"] == model_path:
        return _model_cache["model"], _model_cache["tokenizer"]
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        # NOTE(review): checkpoint dir name says "8kCtx" but context here is
        # capped at 4096 — confirm this is intentional.
        max_seq_length=4096,
        load_in_4bit=False,   # full-precision weights (no quantization)
        load_in_8bit=False,
        full_finetuning=False,  # presumably an adapter/LoRA fine-tune — TODO confirm
    )
    _model_cache.update(model=model, tokenizer=tokenizer, path=model_path)
    return model, tokenizer
# ===========================
# INFERENCE FUNCTION
# ===========================
def infer_reasonableness(
    reference_summary: str,
    generated_summary: str,
    readability_level: str,
    subclaim_text: str,
    result: int,
    model_path: str = "/home/mshahidul/readctrl_model/qwen3-32B_subclaims-completeness_resonability_check_8kCtx_v3",
):
    """Judge whether including/omitting one subclaim is reasonable.

    Given the reference summary, generated summary, readability level,
    subclaim, and its verifier result (1 = included, 0 = omitted), ask the
    fine-tuned model for a reasonableness label (reasonable /
    partially_reasonable / unreasonable) plus a short justification.

    Returns:
        dict parsed from the model's JSON answer, or the raw decoded string
        when the output is not valid JSON.
    """
    model, tokenizer = load_finetuned_model(model_path)

    # ---- Build inference prompt (same structure as training) ----
    prompt = f"""
You are an impartial medical summarization evaluator.
Goal:
Decide whether the inclusion or omission of ONE specific subclaim from the reference summary is *reasonable*, given the readability level of the generated summary.
Readability Criteria:
- Easy: for non-medical readers; emphasize main story and outcomes; omit numerical data, anatomy, and test details.
- Intermediate: for general educated readers; keep main findings but simplify phrasing.
- Hard: for clinical or technical readers; maintain diagnostic accuracy and essential quantitative or anatomic content.
Judging rules:
* Base your decision strictly on what appears in the generated summary.
* If result = 0 (subclaim omitted) and the omitted detail is clearly technical or numerical for the given level, choose "reasonable".
* If result = 0 and the subclaim is essential to the main story, choose "unreasonable".
* Stay consistent between `result`, justification, and readability level.
### Inputs
Readability Level: {readability_level}
Reference Summary: {reference_summary}
Generated Summary: {generated_summary}
Subclaim: "{subclaim_text}"
Result: {result} # 1 = supported (included), 0 = omitted
### Task
Respond **only** with the following JSON object:
{{
"reasonableness": "<reasonable | partially_reasonable | unreasonable>",
"justification": "<short clear explanation>"
}}
""".strip()

    messages = [{"role": "user", "content": prompt + "\n"}]
    chat_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,  # important for Unsloth chat template
    )
    inputs = tokenizer(chat_text, return_tensors="pt").to("cuda")

    # ---- Generate output ----
    # Greedy decoding: with do_sample=False, temperature/top_p/top_k are
    # ignored anyway, so they are omitted to avoid transformers warnings.
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=False,
        )

    # Decode only the newly generated tokens (skip the echoed prompt).
    prompt_len = inputs["input_ids"].shape[-1]
    output_text = tokenizer.decode(
        output_ids[0][prompt_len:], skip_special_tokens=True
    ).strip()
    # Strip a residual think block and markdown fences. Taking the LAST
    # split segment is safe even when no "</think>" tag was emitted — the
    # original hard index [1] raised IndexError in that case.
    output_text = (
        output_text.split("</think>")[-1]
        .strip()
        .replace("```json", "")
        .replace("```", "")
    )

    # ---- Extract model JSON output ----
    try:
        parsed = json.loads(output_text)
    except json.JSONDecodeError:
        # Fall back to the raw text so the caller can inspect/log it.
        parsed = output_text
    return parsed
# ===========================
# EXAMPLE USAGE
# ===========================
# ===========================
# EXAMPLE USAGE
# ===========================
if __name__ == "__main__":
    # json is already imported at module level; only tqdm is needed here.
    import tqdm

    # --- Gold reference summaries and full texts, keyed by case id ---
    with open('/home/mshahidul/readctrl/data/testing_data_gs/multiclinsum_gs_train_es.json', 'r') as f:
        multiclinsum_gs_train_es_data = json.load(f)
    ref_summaries = {item['id']: item['summary'] for item in multiclinsum_gs_train_es_data}
    # Kept for parity with ref_summaries even though no downstream code in
    # this script currently reads it.
    fulltexts = {item['id']: item['fulltext'] for item in multiclinsum_gs_train_es_data}

    # --- Generated summaries, keyed by (case id, readability version) ---
    generated_summaries = {}
    with open('/home/mshahidul/readctrl/data/hand_create_gpt5_other_model/synthetic_data_es_raw_592.json', 'r') as f:
        synthetic_data_es_raw_592 = json.load(f)
    for item in synthetic_data_es_raw_592:
        for version in ['easy', 'intermediate', 'hard']:
            generated_summaries[(item['id'], version)] = item['readability_versions'][version]['text']

    # --- Subclaim verifier outputs to be judged for reasonableness ---
    with open("/home/mshahidul/readctrl/results/dataset_quality_check/subclaim_verifier_results_100_qwen3-32B.json", 'r') as f:
        qwen3_32B_results = json.load(f)

    full_res = []
    save_path = "/home/mshahidul/readctrl/results/dataset_quality_check/completeness_resonability_check_100_qwen3-32B_v4.json"

    # total= lets tqdm show real progress despite the enumerate wrapper.
    for idx, item in tqdm.tqdm(enumerate(qwen3_32B_results), total=len(qwen3_32B_results)):
        print(f"Processing item {idx + 1}/{len(qwen3_32B_results)}")
        reference_summary = ref_summaries[item['id']]
        generated_summary = generated_summaries[(item['id'], item['version'])]
        temp_res = []
        for item2 in item['completeness']['results']:
            subclaim_text = item2['subclaim']['subclaim']
            result = item2['result']
            # Only omitted subclaims get a reasonableness judgment; "1"
            # (supported/included) is skipped. NOTE(review): result appears
            # to be stored as a STRING in the verifier output — confirm
            # against that file's schema, an int 1 would not be skipped here.
            if result == "1":
                continue
            response = infer_reasonableness(
                reference_summary,
                generated_summary,
                item['version'],
                subclaim_text,
                result,
                model_path="/home/mshahidul/readctrl_model/qwen3-32B_subclaims-completeness_resonability_check_8kCtx_v3",
            )
            temp_res.append({
                'id': item2['subclaim']['id'],
                "subclaim": subclaim_text,
                "result": result,
                "reasonableness": response
            })
        full_res.append({
            "id": item['id'],
            "version": item['version'],
            "completeness": {
                "results": temp_res
            }
        })
        # Checkpoint every 10 items so a crash doesn't lose all progress.
        if len(full_res) % 10 == 0:
            with open(save_path, 'w') as f:
                json.dump(full_res, f, indent=2, ensure_ascii=False)

    # Final save of the complete result set.
    with open(save_path, 'w') as f:
        json.dump(full_res, f, indent=2, ensure_ascii=False)