| import os |
| os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" |
| os.environ["CUDA_VISIBLE_DEVICES"] = "5" |
| import argparse |
| import json |
| import re |
| from datetime import datetime |
| from typing import Any, Dict, List, Optional |
|
|
| from vllm import LLM, SamplingParams |
| from transformers import AutoTokenizer |
|
|
|
|
| |
# --- Filesystem layout ------------------------------------------------------
# NOTE(review): hard-coded to one machine/user; consider env-var or CLI overrides.
BASE_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo"
PROMPT_DIR = os.path.join(BASE_DIR, "prompt_en")  # per-label prompt templates
TEST_JSON = os.path.join(BASE_DIR, "dataset", "en", "test_en.json")  # input dataset
RESULTS_DIR = os.path.join(BASE_DIR, "results", "en")  # output JSONs land here


# Language name substituted into the {source_lang} placeholder of the
# generation prompt templates.
SOURCE_LANG = "English"


# Maps each dataset label to the prompt-template filename inside PROMPT_DIR.
LABEL_TO_PROMPT_FILE: Dict[str, str] = {
    "low_health_literacy": "prompt_low",
    "intermediate_health_literacy": "prompt_intermediate",
    "proficient_health_literacy": "prompt_proficient",
}


# Human-readable readability requirement per label; this text is embedded
# verbatim into the critique and revision prompts.
LABEL_TO_READABILITY: Dict[str, str] = {
    "low_health_literacy": (
        "Low Health Literacy (High Readability): individuals needing the simplest "
        "terms for immediate action, using 'living room' language, one idea per "
        "sentence, and focusing only on need-to-know information from the Gold Summary."
    ),
    "intermediate_health_literacy": (
        "Intermediate Health Literacy (Medium Readability): the general public at a "
        "news-reading level, with standard vocabulary and some common medical terms, "
        "and a balanced level of detail led by the Gold Summary."
    ),
    "proficient_health_literacy": (
        "Proficient Health Literacy (Low Readability): researchers, clinicians, or "
        "highly informed patients, using technical and academic language, high "
        "information density, and full clinical nuance and terminology from the "
        "Source Text."
    ),
}
|
|
|
|
def load_prompts(prompt_dir: str) -> Dict[str, str]:
    """Read one prompt template per health-literacy label from *prompt_dir*.

    Returns a mapping of label -> template text.
    Raises FileNotFoundError if any expected prompt file is missing.
    """
    loaded: Dict[str, str] = {}
    for label, fname in LABEL_TO_PROMPT_FILE.items():
        prompt_path = os.path.join(prompt_dir, fname)
        # Fail fast with a guard clause instead of nesting the happy path.
        if not os.path.isfile(prompt_path):
            raise FileNotFoundError(f"Prompt file not found: {prompt_path}")
        with open(prompt_path, "r", encoding="utf-8") as fh:
            loaded[label] = fh.read()
    return loaded
|
|
|
|
def build_generation_user_message(
    prompt_template: str,
    full_text: str,
    gold_summary: str,
    source_lang: str = SOURCE_LANG,
) -> str:
    """Fill the {full_text}, {gold_summary} and {source_lang} slots of a template."""
    filled = prompt_template
    for placeholder, value in (
        ("{full_text}", full_text),
        ("{gold_summary}", gold_summary),
        ("{source_lang}", source_lang),
    ):
        filled = filled.replace(placeholder, value)
    return filled
|
|
|
|
def extract_summary_from_json_str(raw: str, expected_key: str) -> str:
    """
    Extract the summary string from a JSON-like model output.

    Recovery strategy, in order:
      1. Strip a Markdown code fence (``` ... ```) if present.
      2. Narrow to the outermost {...} span and parse it as JSON, returning
         the value under *expected_key* (or the sole value of a one-key dict).
      3. Regex fallback for malformed JSON that tolerates escaped quotes in
         the value and decodes JSON escape sequences.

    Falls back to returning the raw text (stripped) if everything fails.
    """
    text = raw.strip()

    # Drop a Markdown fence the model may have wrapped its JSON in.
    if text.startswith("```"):
        lines = text.splitlines()
        if lines:
            lines = lines[1:]  # opening fence (possibly "```json")
        if lines and lines[-1].strip().startswith("```"):
            lines = lines[:-1]  # closing fence
        text = "\n".join(lines).strip()

    # Narrow to the outermost {...} so leading/trailing chatter is ignored.
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end != -1 and end > start:
        candidate = text[start : end + 1]
    else:
        candidate = text

    # Preferred path: real JSON parsing.
    try:
        obj = json.loads(candidate)
        if isinstance(obj, dict):
            if expected_key in obj and isinstance(obj[expected_key], str):
                return obj[expected_key].strip()
            # Single-key object: accept whatever key the model actually used.
            if len(obj) == 1:
                val = next(iter(obj.values()))
                if isinstance(val, str):
                    return val.strip()
    except Exception:
        pass

    # Regex fallback for malformed JSON. The value pattern permits escaped
    # characters (e.g. \" or \n); the previous [^"]* pattern truncated the
    # value at the first escaped quote inside it.
    key_pattern = re.escape(expected_key)
    m = re.search(rf'"{key_pattern}"\s*:\s*"((?:[^"\\]|\\.)*)"', candidate)
    if m:
        try:
            # Decode JSON escapes (\n, \", \uXXXX) in the captured value.
            return json.loads(f'"{m.group(1)}"').strip()
        except Exception:
            return m.group(1).strip()

    return raw.strip()
|
|
|
|
def build_critique_user_message(
    label: str,
    current_summary: str,
) -> str:
    """Compose the user turn that asks the model to critique *current_summary*."""
    readability = LABEL_TO_READABILITY.get(label, label)
    instructions = "\n".join(
        [
            "Instructions:",
            "1. Focus ONLY on clarity, plain language, sentence structure, and suitability for the target reader.",
            "2. Do NOT add new medical facts that are not already present.",
            "3. Identify concrete issues and suggest improvements as bullet points.",
        ]
    )
    sections = [
        "You are an expert medical editor and Health Literacy specialist.",
        "Read the following patient-facing summary and critique its **readability** for this audience:",
        readability,
        instructions,
        f"Summary to critique:\n{current_summary}",
        "Now provide a concise critique in bullet points.",
    ]
    return "\n\n".join(sections)
|
|
|
|
def build_revision_user_message(
    label: str,
    current_summary: str,
    critique: str,
) -> str:
    """Compose the user turn that asks the model to revise the summary per *critique*."""
    readability = LABEL_TO_READABILITY.get(label, label)
    # The JSON key requested in the output mirrors the health-literacy label.
    expected_key = label
    parts = [
        "You are an expert medical editor and Health Literacy specialist.\n\n",
        "Goal: Rewrite the summary so it better matches this readability requirement:\n\n",
        f"{readability}\n\n",
        "Use ONLY the information already present in the original summary. ",
        "Do NOT introduce new clinical facts.\n\n",
        "Original summary:\n",
        f"{current_summary}\n\n",
        "Your previous readability critique:\n",
        f"{critique}\n\n",
        "Now produce an improved version of the summary that addresses the critique.\n",
        "Output **JSON only** with this exact structure:\n",
        f'{{\n "{expected_key}": "..." \n}}\n',
    ]
    return "".join(parts)
|
|
|
|
def generate_single(
    llm: LLM,
    sampling_params: SamplingParams,
    tokenizer,
    user_content: str,
) -> str:
    """Run one single-turn chat completion and return the stripped output text."""
    messages = [{"role": "user", "content": user_content}]
    rendered = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    request_outputs = llm.generate([rendered], sampling_params=sampling_params)
    # One prompt in, one RequestOutput out; take its first (and only) completion.
    return request_outputs[0].outputs[0].text.strip()
|
|
|
|
def self_refine_example(
    llm: LLM,
    tokenizer,
    item: Dict[str, Any],
    prompts: Dict[str, str],
    num_iterations: int,
    gen_sampling: SamplingParams,
    critique_sampling: SamplingParams,
    revise_sampling: SamplingParams,
    source_lang: str = SOURCE_LANG,
) -> Dict[str, Any]:
    """Run the full self-refinement loop for one dataset item.

    Produces an initial summary from the prompt template, then runs
    *num_iterations* critique -> revise cycles, recording every step in a
    per-iteration history list that is returned with the final summary.
    """
    label: str = item.get("label")
    doc_id = item.get("doc_id")
    fulltext = item.get("fulltext", "")
    gold_summary = item.get("summary", "")
    gold_gen_text = item.get("gen_text", "")

    # Bail out early on labels we have no prompt template for.
    if label not in prompts:
        return {"doc_id": doc_id, "label": label, "error": f"Unknown label: {label}"}

    template = prompts[label]
    history: List[Dict[str, Any]] = []

    # Iteration 0: initial generation from the full text + gold summary.
    raw_initial = generate_single(
        llm=llm,
        sampling_params=gen_sampling,
        tokenizer=tokenizer,
        user_content=build_generation_user_message(
            prompt_template=template,
            full_text=fulltext,
            gold_summary=gold_summary,
            source_lang=source_lang,
        ),
    )
    current_summary = extract_summary_from_json_str(raw_initial, expected_key=label)
    history.append(
        {"iteration": 0, "summary": current_summary, "raw_model_output": raw_initial}
    )

    # Iterations 1..num_iterations: critique the current summary, then revise it.
    for step in range(1, num_iterations + 1):
        raw_critique = generate_single(
            llm=llm,
            sampling_params=critique_sampling,
            tokenizer=tokenizer,
            user_content=build_critique_user_message(
                label=label, current_summary=current_summary
            ),
        )
        raw_revised = generate_single(
            llm=llm,
            sampling_params=revise_sampling,
            tokenizer=tokenizer,
            user_content=build_revision_user_message(
                label=label,
                current_summary=current_summary,
                critique=raw_critique,
            ),
        )
        revised_summary = extract_summary_from_json_str(raw_revised, expected_key=label)
        history.append(
            {
                "iteration": step,
                "critique": raw_critique,
                "revised_summary": revised_summary,
                "raw_revision_output": raw_revised,
            }
        )
        current_summary = revised_summary

    return {
        "doc_id": doc_id,
        "label": label,
        "readability_requirement": LABEL_TO_READABILITY.get(label, label),
        "gold_summary": gold_summary,
        "gold_gen_text": gold_gen_text,
        "initial_summary": history[0]["summary"],
        "final_summary": current_summary,
        "iterations": history,
        "error": None,
    }
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface for the refinement run."""
    parser = argparse.ArgumentParser(
        description=(
            "Run a self-refinement loop (generate → critique → revise) "
            "with Qwen/Qwen3-4B-Instruct-2507 on test_en.json."
        )
    )
    add = parser.add_argument

    # Model and data locations.
    add(
        "--model-id",
        type=str,
        default="Qwen/Qwen3-4B-Instruct-2507",
        help="Hugging Face model id or local path for the Qwen3 instruct model.",
    )
    add(
        "--prompt-dir",
        type=str,
        default=PROMPT_DIR,
        help="Directory containing prompt files (prompt_low, prompt_intermediate, prompt_proficient).",
    )
    add(
        "--test-json",
        type=str,
        default=TEST_JSON,
        help="Path to the input test/dataset JSON file.",
    )
    add(
        "--src-lang",
        type=str,
        default=SOURCE_LANG,
        help="Source language name used in the generation prompt (e.g. English).",
    )

    # Refinement loop controls.
    add(
        "--num-iterations",
        type=int,
        default=5,
        help="Number of critique+revise iterations to run per example.",
    )
    add(
        "--max-new-tokens",
        type=int,
        default=512,
        help="Maximum new tokens for summary generation.",
    )
    add(
        "--critique-max-new-tokens",
        type=int,
        default=256,
        help="Maximum new tokens for critique generation.",
    )
    add(
        "--revise-max-new-tokens",
        type=int,
        default=512,
        help="Maximum new tokens for revision generation.",
    )
    add(
        "--temperature",
        type=float,
        default=0.1,
        help="Sampling temperature for generation and revision.",
    )
    add(
        "--critique-temperature",
        type=float,
        default=0.3,
        help="Sampling temperature for critique (usually lower).",
    )

    # Debugging / output controls.
    add(
        "--limit",
        type=int,
        default=None,
        help="Optional limit on number of examples from test_en.json (for debugging).",
    )
    add(
        "--output-file",
        type=str,
        default=None,
        help=(
            "Optional path for the main results JSON file. "
            "If not set, a timestamped name in the results directory is used."
        ),
    )
    return parser.parse_args()
|
|
|
|
def main() -> None:
    """Entry point: load prompts/data/model, run self-refinement, save results.

    Side effects:
      - Creates RESULTS_DIR if it does not exist.
      - Writes a detailed per-example results JSON and a separate run-metadata
        "summary" JSON (paths derived from --output-file or a timestamp).
    """
    args = parse_args()


    os.makedirs(RESULTS_DIR, exist_ok=True)


    print("Loading prompts from", args.prompt_dir)
    prompts = load_prompts(args.prompt_dir)


    print("Loading test data from", args.test_json)
    with open(args.test_json, "r", encoding="utf-8") as f:
        test_list: List[Dict[str, Any]] = json.load(f)


    # Optional truncation for quick debugging runs.
    if args.limit is not None:
        test_list = test_list[: args.limit]
        print(f"Limiting to first {len(test_list)} examples.")
    else:
        print(f"Total examples: {len(test_list)}")


    print("Loading tokenizer and model:", args.model_id)
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    llm = LLM(
        model=args.model_id,
        trust_remote_code=True,
    )


    # Separate sampling configs so the critique step can use its own
    # temperature and token budget.
    gen_sampling = SamplingParams(
        temperature=args.temperature,
        max_tokens=args.max_new_tokens,
        n=1,
    )
    critique_sampling = SamplingParams(
        temperature=args.critique_temperature,
        max_tokens=args.critique_max_new_tokens,
        n=1,
    )
    revise_sampling = SamplingParams(
        temperature=args.temperature,
        max_tokens=args.revise_max_new_tokens,
        n=1,
    )


    results: List[Dict[str, Any]] = []


    total = len(test_list)
    for idx, item in enumerate(test_list):
        print(f"\n=== Processing example {idx + 1}/{total} (doc_id={item.get('doc_id')}, label={item.get('label')}) ===")
        try:
            example_result = self_refine_example(
                llm=llm,
                tokenizer=tokenizer,
                item=item,
                prompts=prompts,
                num_iterations=args.num_iterations,
                gen_sampling=gen_sampling,
                critique_sampling=critique_sampling,
                revise_sampling=revise_sampling,
                source_lang=args.src_lang,
            )
        except Exception as e:
            # Record the failure per example so one bad item does not
            # abort the whole run.
            example_result = {
                "doc_id": item.get("doc_id"),
                "label": item.get("label"),
                "error": f"Exception during self-refinement: {e}",
            }
        results.append(example_result)


    # Derive output paths: honor --output-file (adding .json if no extension),
    # otherwise fall back to timestamped names in RESULTS_DIR.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    if args.output_file:
        out_path = args.output_file
        base, ext = os.path.splitext(out_path)
        if not ext:
            out_path = base + ".json"
        base = out_path.rsplit(".", 1)[0]
        summary_path = base + "_summary.json"
    else:
        out_path = os.path.join(RESULTS_DIR, f"self_refine_qwen3_{timestamp}.json")
        summary_path = os.path.join(
            RESULTS_DIR, f"self_refine_qwen3_{timestamp}_summary.json"
        )


    print("\nSaving detailed self-refinement results to", out_path)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)


    # Run-level metadata written alongside the detailed results.
    summary: Dict[str, Any] = {
        "model_id": args.model_id,
        "prompt_dir": os.path.abspath(args.prompt_dir),
        "test_json": os.path.abspath(args.test_json),
        "src_lang": args.src_lang,
        "num_test_samples": len(test_list),
        "results_file": out_path,
        "timestamp": timestamp,
        "num_iterations": args.num_iterations,
        "max_new_tokens": args.max_new_tokens,
        "critique_max_new_tokens": args.critique_max_new_tokens,
        "revise_max_new_tokens": args.revise_max_new_tokens,
        "temperature": args.temperature,
        "critique_temperature": args.critique_temperature,
    }


    print("Saving summary metadata to", summary_path)
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)


    print("\nDone.")




if __name__ == "__main__":
    main()
|
|
|
|