# completeness_conciseness_attribution_cal.py
# (originally at readCtrl_lambda/code/finetune-inference/old/; GitHub page
# residue — committer name, commit message, and hash — removed so the file parses)
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import torch
from unsloth import FastLanguageModel
import json
# Module-level cache so the expensive model load happens at most once per process.
_model_cache = {"model": None, "tokenizer": None}


def load_finetuned_model(model_path: str):
    """Return the fine-tuned model and tokenizer, loading them on first use.

    Subsequent calls return the cached pair and ignore *model_path* — the
    cache is keyed on nothing, so only one model can live per process.
    """
    if _model_cache["model"] is None:
        loaded_model, loaded_tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_path,
            max_seq_length=4092,  # NOTE(review): 4092 looks like a typo for 4096 — confirm
            load_in_4bit=False,
            load_in_8bit=False,
            full_finetuning=False,
        )
        _model_cache["model"] = loaded_model
        _model_cache["tokenizer"] = loaded_tokenizer
    return _model_cache["model"], _model_cache["tokenizer"]
def infer_subclaim(text: str, subclaim: str, model_path: str = "/home/mshahidul/readctrl_model/qwen3-32B_subclaims-verifier_lora_nonreasoning", cuda_device: str = "0") -> str:
    """
    Given a medical text and a subclaim, returns '1' if the text supports the subclaim, otherwise '0'.

    Args:
        text: Source passage the subclaim is checked against.
        subclaim: Single claim to verify.
        model_path: Checkpoint passed to load_finetuned_model; only honored on
            the first call, because the model is cached process-wide.
        cuda_device: Unused; kept for backward compatibility with callers.

    Returns:
        The model's raw answer text (expected to contain '1' or '0').
    """
    model, tokenizer = load_finetuned_model(model_path)
    # Build prompt (the same structure the model was trained on).
    prompt = f"""
Given the following medical text and subclaim, decide if the text supports the subclaim.
Text: {text}
Subclaim: {subclaim}
Respond only with 1 if the text supports the subclaim, otherwise 0.
""".strip()
    messages = [{"role": "user", "content": prompt + "\n"}]
    chat_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    inputs = tokenizer(chat_text, return_tensors="pt").to("cuda")
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=10,
            # NOTE(review): sampling params below are ignored unless
            # do_sample=True is passed — confirm intended decoding mode.
            temperature=0.1,
            top_p=0.8,
            top_k=5,
        )
    # Decode only the newly generated tokens so prompt text cannot leak into
    # the returned verdict.
    prompt_len = inputs["input_ids"].shape[-1]
    output_text = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True).strip()
    # BUG FIX: the original used split("</think>")[1], which raises IndexError
    # whenever the output contains no "</think>" tag (the common case with
    # enable_thinking=False). Taking the last segment handles both the tagged
    # and untagged cases identically when the tag IS present.
    return output_text.split("</think>")[-1].strip()
if __name__ == "__main__":

    def _evaluate_subclaims(metric, version, text, subclaims):
        """Verify each subclaim against *text* and aggregate an accuracy score.

        For every subclaim the verifier returns raw text; answers containing
        '1' count as supported, '0' as unsupported, and anything else is
        recorded but excluded from the accuracy denominator.
        """
        results = []
        total = 0
        correct = 0
        for claim in subclaims:
            verdict = infer_subclaim(text, claim)
            if "1" in verdict:
                correct += 1
                total += 1
            elif "0" in verdict:
                total += 1
            results.append({"subclaim": claim, "result": verdict})
        return {
            "metric": metric,
            "version": version,
            "input_text": text,
            "results": results,
            "total": total,
            "correct": correct,
            "accuracy": (correct / total) * 100 if total > 0 else 0,
        }

    def process_completeness(example, version):
        """Completeness: are the reference summary's subclaims supported by the readability version?"""
        return _evaluate_subclaims(
            "completeness",
            version,
            example["readability_versions"][version]["text"],
            example["ref_summary"]["subclaims"],
        )

    def process_conciseness(example, version):
        """Conciseness: are the readability version's subclaims supported by the reference summary?"""
        return _evaluate_subclaims(
            "conciseness",
            version,
            example["ref_summary"]["text"],
            example["readability_versions"][version]["subclaims"],
        )

    def process_attribution(example, version):
        """Attribution: are the readability version's subclaims supported by the full source text?"""
        return _evaluate_subclaims(
            "attribution",
            version,
            example["full_text"],
            example["readability_versions"][version]["subclaims"],
        )

    with open("/home/mshahidul/readctrl/data/training_data_subclaim_verifier/synthetic_data_es_subclaims_100.json", "r", encoding="utf-8") as f:
        data = json.load(f)

    import tqdm  # local import preserved from the original script

    full_data_results = []
    save_path = "/home/mshahidul/readctrl/results/dataset_quality_check/subclaim_verifier_results_100_qwen3-32B.json"

    def _save(results):
        """Write the accumulated results to *save_path* (pretty-printed, UTF-8)."""
        with open(save_path, "w", encoding="utf-8") as out:
            json.dump(results, out, indent=4, ensure_ascii=False)

    for item in tqdm.tqdm(data):
        print(f"Processing item ID: {item['id']}")
        for version in ["easy", "intermediate", "hard"]:
            completeness = process_completeness(item, version)
            conciseness = process_conciseness(item, version)
            attribution = process_attribution(item, version)
            full_data_results.append({
                "id": item["id"],
                "version": version,
                "completeness": completeness,
                "conciseness": conciseness,
                "attribution": attribution,
            })
            # Checkpoint every 5 records so a crash loses little work.
            if len(full_data_results) % 5 == 0:
                _save(full_data_results)
    _save(full_data_results)