# completeness_reasoning_v2.py (readctrl/code/finetune-inference/old)
# ===========================
# GPU SETTINGS
# ===========================
import os

# Pin to a single GPU; must be set before torch is imported to take effect.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

import torch
from unsloth import FastLanguageModel
import json

# ===========================
# MODEL LOADING (CACHED)
# ===========================
_model_cache = {"model": None, "tokenizer": None}
def load_finetuned_model(model_path: str):
"""Load and cache the fine-tuned model + tokenizer."""
if _model_cache["model"] is not None:
return _model_cache["model"], _model_cache["tokenizer"]
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=model_path,
max_seq_length=4096,
load_in_4bit=False,
load_in_8bit=False,
full_finetuning=False,
)
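    # Switch to Unsloth's optimized inference mode. Assumes a recent Unsloth
    # release where FastLanguageModel.for_inference is available.
    FastLanguageModel.for_inference(model)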
_model_cache["model"], _model_cache["tokenizer"] = model, tokenizer
return model, tokenizer
# ===========================
# INFERENCE FUNCTION
# ===========================
def infer_reasonableness(
reference_summary: str,
generated_summary: str,
readability_level: str,
subclaim_text: str,
result: int,
model_path: str = "/home/mshahidul/readctrl_model/qwen3-32B_subclaims-completeness_resonability_check_8kCtx_v3",
):
"""
Given the reference summary, generated summary, readability level, subclaim, and its result (0/1),
predict reasonableness: reasonable / partially_reasonable / unreasonable, plus justification.
"""
model, tokenizer = load_finetuned_model(model_path)
# ---- Build inference prompt (same structure as training) ----
prompt = f"""
You are an impartial medical summarization evaluator.
Goal:
Decide whether the inclusion or omission of ONE specific subclaim from the reference summary is *reasonable*, given the readability level of the generated summary.
Readability Criteria:
- Easy: for non-medical readers; emphasize main story and outcomes; omit numerical data, anatomy, and test details.
- Intermediate: for general educated readers; keep main findings but simplify phrasing.
- Hard: for clinical or technical readers; maintain diagnostic accuracy and essential quantitative or anatomic content.
Judging rules:
* Base your decision strictly on what appears in the generated summary.
* If result = 0 (subclaim omitted) and the omitted detail is clearly technical or numerical for the given level, choose "reasonable".
* If result = 0 and the subclaim is essential to the main story, choose "unreasonable".
* Stay consistent between `result`, justification, and readability level.
### Inputs
Readability Level: {readability_level}
Reference Summary: {reference_summary}
Generated Summary: {generated_summary}
Subclaim: "{subclaim_text}"
Result: {result} # 1 = supported (included), 0 = omitted
### Task
Respond **only** with the following JSON object:
{{
"reasonableness": "<reasonable | partially_reasonable | unreasonable>",
"justification": "<short clear explanation>"
}}
""".strip()
messages = [{"role": "user", "content": prompt + "\n"}]
chat_text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
        enable_thinking=False,  # disable Qwen3 "thinking" blocks to match the training setup
)
inputs = tokenizer(chat_text, return_tensors="pt").to("cuda")
# ---- Generate output ----
    # Greedy decoding: with do_sample=False the sampling knobs (temperature,
    # top_p, top_k) are ignored and only trigger warnings, so they are omitted.
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=False,
        )
    # Decode only the newly generated tokens (exclude the prompt).
    gen_ids = output_ids[0][inputs["input_ids"].shape[1]:]
    output_text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
    # Drop any "</think>" block ([-1] is safe when the tag is absent, unlike the
    # original [1] indexing) and strip markdown code fences around the JSON.
    output_text = output_text.split("</think>")[-1].strip().replace("```json", "").replace("```", "")
    # ---- Extract model JSON output ----
    try:
        parsed = json.loads(output_text)
    except json.JSONDecodeError:
        # Fall back to the raw text when the reply is not valid JSON.
        parsed = output_text
return parsed
# ===========================
# EXAMPLE USAGE
# ===========================
if __name__ == "__main__":
    # reference_summary = "Una niña nacida a las 34 semanas de gestación precisó intubación..."
    # generated_summary = "Esta es la historia de una niña que nació antes de tiempo, a las 34 semanas..."
    # subclaim_text = "La paciente presentaba hiperinsulinismo en el período neonatal."
    # readability_level = "easy"
    # result = 0  # omitted
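    # Gold-standard Spanish clinical cases: reference summaries and full texts,
    # keyed by document id.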
with open('/home/mshahidul/readctrl/data/testing_data_gs/multiclinsum_gs_train_es.json', 'r') as f:
multiclinsum_gs_train_es_data = json.load(f)
    ref_summaries = {}
    fulltexts = {}
    for item in multiclinsum_gs_train_es_data:
        ref_summaries[item['id']] = item['summary']
        fulltexts[item['id']] = item['fulltext']
generated_summaries = {}
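    # Readability-controlled synthetic summaries, keyed by (id, version).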
with open('/home/mshahidul/readctrl/data/hand_create_gpt5_other_model/synthetic_data_es_raw_592.json', 'r') as f:
synthetic_data_es_raw_592 = json.load(f)
for item in synthetic_data_es_raw_592:
for version in ['easy', 'intermediate', 'hard']:
generated_summaries[(item['id'], version)] = item['readability_versions'][version]['text']
with open("/home/mshahidul/readctrl/results/dataset_quality_check/subclaim_verifier_results_100_qwen3-32B.json", 'r') as f:
qwen3_32B_results = json.load(f)
full_res = []
save_path = "/home/mshahidul/readctrl/results/dataset_quality_check/completeness_resonability_check_100_qwen3-32B_v4.json"
    from tqdm import tqdm
    # tqdm reports progress; the old per-item print would only break the bar.
    for idx, item in tqdm(enumerate(qwen3_32B_results), total=len(qwen3_32B_results)):
reference_summary = ref_summaries[item['id']]
fulltext = fulltexts[item['id']]
generated_summary = generated_summaries[(item['id'], item['version'])]
temp_res = []
        for item2 in item['completeness']['results']:
            subclaim_text = item2['subclaim']['subclaim']
            result = item2['result']
            # Only omitted subclaims need a reasonableness judgment; the verifier
            # stores results as strings, so compare via str() to be safe.
            if str(result) == "1":
                continue
response = infer_reasonableness(
reference_summary,
generated_summary,
item['version'],
subclaim_text,
result,
model_path="/home/mshahidul/readctrl_model/qwen3-32B_subclaims-completeness_resonability_check_8kCtx_v3",
)
            temp_res.append({
                'id': item2['subclaim']['id'],
                "subclaim": subclaim_text,
                "result": result,
                "reasonableness": response
            })
full_res.append({
"id": item['id'],
"version": item['version'],
"completeness": {
"results": temp_res
}
})
        # Checkpoint every 10 items so partial results survive an interruption.
        if len(full_res) % 10 == 0:
            with open(save_path, 'w') as f:
                json.dump(full_res, f, indent=2, ensure_ascii=False)
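    # Final write with the complete result set.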
with open(save_path, 'w') as f:
json.dump(full_res, f, indent=2, ensure_ascii=False)