# readctrl/code/attribution_evalV2.py
# Attribution reasonableness evaluation (v2) for subclaim-verifier outputs.
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
import json
import torch
from unsloth import FastLanguageModel
import tqdm
# Process-wide cache so the model is loaded into GPU memory exactly once.
_model_cache = {"model": None, "tokenizer": None}


def load_finetuned_model(model_path: str):
    """Return the fine-tuned (model, tokenizer) pair, loading on first call.

    Subsequent calls return the cached pair and ignore `model_path`.
    """
    cached_model = _model_cache["model"]
    if cached_model is not None:
        return cached_model, _model_cache["tokenizer"]

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        max_seq_length=8192,
        load_in_4bit=False,
        load_in_8bit=False,
        full_finetuning=False,
    )
    _model_cache.update(model=model, tokenizer=tokenizer)
    return model, tokenizer
def build_inference_prompt(
    reference_full_text,
    generated_summary,
    subclaim_id,
    subclaim_text,
    subclaim_result,
    difficulty_level
):
    """Build a standardized evaluation prompt for a single subclaim.

    Args:
        reference_full_text: Source document the summary was generated from.
        generated_summary: The readability-controlled summary under evaluation.
        subclaim_id: Identifier of the subclaim (interpolated into the JSON stub).
        subclaim_text: The subclaim sentence itself.
        subclaim_result: Verifier flag — 1 = supported, 0 = unsupported.
        difficulty_level: Readability level: "easy", "intermediate", or "hard".

    Returns:
        The fully formatted prompt string (stripped of surrounding whitespace).
    """
    # NOTE: the literal {{ }} braces below escape f-string interpolation so the
    # prompt shows a JSON object to the model.
    inference_prompt = f"""
### **SYSTEM / ROLE INSTRUCTION**
You are a **medical factuality and attribution evaluator**.
You will analyze one subclaim from a generated medical summary.
Each subclaim includes a `"result"` flag:
- `1` → Supported by the reference text (no reasonableness check required)
- `0` → Unsupported by the reference text (evaluate scope and validity)
Your task is to decide, for unsupported subclaims, whether the new information
is a *reasonable addition* given the specified readability level:
**easy**, **intermediate**, or **hard**.
---
### **READABILITY GUIDELINES**
| Level | Audience | Style | Allowable Additions |
| :-- | :-- | :-- | :-- |
| **Easy (FH 70–100)** | General public | Simple, concrete | Broad clarifications only; no factual innovations |
| **Intermediate (FH 50–69)** | Educated nonspecialist | Moderate precision | Limited clarifications consistent with the text |
| **Hard (FH 0–49)** | Professionals | Formal, technical | Must be strictly supported by evidence |
---
### **INPUT**
Readability Level: {difficulty_level}
Reference Full Text:
{reference_full_text}
Generated Summary:
{generated_summary}
Subclaim Info:
{{
"subclaim_id": {subclaim_id},
"subclaim": "{subclaim_text}",
"result": {subclaim_result}
}}
---
### **TASK INSTRUCTIONS**
- If `"result": 1`, respond with `"not_applicable"` and justify briefly
(e.g., *"supported, no evaluation required"*).
- If `"result": 0`, classify reasonableness:
- `"reasonable"` → legitimate simplification consistent with the readability level
- `"partially_reasonable"` → benign rephrasing
- `"unreasonable"` → misleading, speculative, or contradicted by the source
Provide a **short 1–2 sentence justification**.
---
### **EXPECTED OUTPUT (JSON ONLY)**
```json
{{
"evaluation": {{
"subclaim_id": {subclaim_id},
"subclaim": "{subclaim_text}",
"result": {subclaim_result},
"reasonableness": "<reasonable | partially_reasonable | unreasonable | not_applicable>",
"justification": "<brief justification>"
}}
}}
""".strip()
    return inference_prompt
def infer_attribution_reasonableness(prompt: str, model_path: str):
    """Run greedy inference with the fine-tuned model on an attribution prompt.

    Args:
        prompt: Fully built evaluation prompt (see `build_inference_prompt`).
        model_path: Path passed to `load_finetuned_model` (cached after first call).

    Returns:
        Parsed JSON (dict/list) when the model output is valid JSON,
        otherwise the raw cleaned output string.
    """
    model, tokenizer = load_finetuned_model(model_path)
    messages = [{"role": "user", "content": prompt + "\n"}]
    chat_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    inputs = tokenizer(chat_text, return_tensors="pt").to("cuda")
    with torch.no_grad():
        # Greedy decoding: temperature/top_p/top_k are ignored when
        # do_sample=False, so they are omitted to avoid warning spam.
        output_ids = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=False,
        )
    # NOTE(review): output_ids[0] includes the prompt tokens, so the decoded
    # text contains the input as well — the </think> split below is what
    # isolates the answer for thinking models. Confirm for non-thinking models.
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    # Drop any chain-of-thought trace preceding the final answer.
    if "</think>" in output_text:
        output_text = output_text.split("</think>")[-1]
    # Bug fix: markdown fences were previously stripped only inside the
    # </think> branch, so fenced output without a think-trace failed to parse.
    output_text = output_text.replace("```json", "").replace("```", "").strip()
    try:
        parsed = json.loads(output_text)
    except json.JSONDecodeError:
        # Deliberate best-effort: fall back to the raw text for manual review.
        parsed = output_text
    return parsed
# --- Script configuration: input datasets and checkpointed output path ---
# Synthetic summaries with per-level readability versions.
file_synth = "/home/mshahidul/readctrl/data/training_data_subclaim_verifier/synthetic_data_es_subclaims_100.json"
# Subclaim-verifier (supported/unsupported) results from Qwen3-32B.
file_qwen_results = "/home/mshahidul/readctrl/results/dataset_quality_check/subclaim_verifier_results_100_qwen3-32B.json"
# Output file; also used for resume-on-restart below.
save_path = "/home/mshahidul/readctrl/results/dataset_quality_check/attribution_resonability_results_100_qwen3-32B_v2.json"
with open(file_synth, 'r') as f:
    synthetic_data = json.load(f)
with open(file_qwen_results, 'r') as f:
    qwen3_32B_results = json.load(f)
# Index verifier results by (entry id, readability version) for O(1) lookup.
dict1={}
for item in qwen3_32B_results:
    version=item['version']
    dict1[(item['id'], version)] = item['attribution']['results']
res = []
# Resume support: reload previously saved evaluations and record which
# (id, level) pairs are already done so they can be skipped.
if os.path.exists(save_path):
    with open(save_path, 'r') as f:
        res = json.load(f)
    print(f"🔁 Resuming from {len(res)} entries")
existing = set((e["id"], e["difficulty_level"]) for e in res)
# Main loop: evaluate attribution reasonableness for the first 100 synthetic
# entries, one pass per readability level, with periodic checkpointing.
for ind in tqdm.tqdm(range(0, 100)):
    entry = synthetic_data[ind]
    for level in ["easy", "intermediate", "hard"]:
        # KeyError here would mean the verifier results are missing this
        # (id, level) pair; lookup happens even for already-done pairs.
        subclaims_results = dict1[(entry["id"], level)]
        if (entry["id"], level) in existing:
            print(f"⏭️ Skipping {entry['id']} ({level})")
            continue
        ref_full_text = entry["full_text"]
        generated_summary = entry["readability_versions"][level]["text"]
        temp=[]
        for subclaim in subclaims_results:
            subclaim_id = subclaim['subclaim']['id']
            subclaim_text = subclaim['subclaim']['subclaim']
            subclaim_result = subclaim['result']
            # NOTE(review): the prompt is built even for supported subclaims
            # that are skipped just below — harmless but wasted work.
            prompt = build_inference_prompt(
                ref_full_text,
                generated_summary,
                subclaim_id,
                subclaim_text,
                subclaim_result,
                level
            )
            # NOTE(review): compares against the string "1"; if the verifier
            # file stores results as ints, supported claims would NOT be
            # skipped here — confirm against the results file schema.
            if subclaim_result=="1":
                temp.append({
                    "subclaim_id": subclaim_id,
                    "subclaim_text": subclaim_text,
                    "response": "not_applicable"
                })
                continue
            # Unsupported subclaim: ask the fine-tuned model for a
            # reasonableness verdict (model is cached after the first call).
            response = infer_attribution_reasonableness(prompt,"/home/mshahidul/readctrl_model/qwen3-32B_subclaims-attribution_resonability_check_8kCtx_v1")
            temp.append({
                "subclaim_id": subclaim_id,
                "subclaim_text": subclaim_text,
                "response": response
            })
        res.append({
            "id": entry["id"],
            "difficulty_level": level,
            "results": temp
        })
        # Checkpoint every 10 accumulated (id, level) results.
        if len(res) % 10 == 0:
            with open(save_path, 'w') as f:
                json.dump(res, f, indent=2, ensure_ascii=False)
            print(f"💾 Saved after {len(res)} entries")
# Final save, covering any tail entries since the last checkpoint.
with open(save_path, 'w') as f:
    json.dump(res, f, indent=2, ensure_ascii=False)