# self_refine_qwen3_vllm.py
# Self-refinement loop (generate -> critique -> revise) for readability-controlled
# medical summarization with Qwen3 via vLLM, over the English test split.
import os
# GPU pinning: set before importing any CUDA-aware library (torch/vLLM),
# otherwise CUDA_VISIBLE_DEVICES may be ignored by an already-initialized runtime.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "5"
import argparse
import json
import re
from datetime import datetime
from typing import Any, Dict, List, Optional
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
# Base paths follow the existing project layout
BASE_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo"
PROMPT_DIR = os.path.join(BASE_DIR, "prompt_en")
TEST_JSON = os.path.join(BASE_DIR, "dataset", "en", "test_en.json")
RESULTS_DIR = os.path.join(BASE_DIR, "results", "en")
# Language name substituted for the {source_lang} placeholder in prompts.
SOURCE_LANG = "English"
# Reuse the same label → prompt mapping used elsewhere
# Maps a dataset label to its prompt file name under PROMPT_DIR.
LABEL_TO_PROMPT_FILE: Dict[str, str] = {
"low_health_literacy": "prompt_low",
"intermediate_health_literacy": "prompt_intermediate",
"proficient_health_literacy": "prompt_proficient",
}
# Audience description per label, injected verbatim into the critique and
# revision prompts to steer the target readability level.
LABEL_TO_READABILITY: Dict[str, str] = {
"low_health_literacy": (
"Low Health Literacy (High Readability): individuals needing the simplest "
"terms for immediate action, using 'living room' language, one idea per "
"sentence, and focusing only on need-to-know information from the Gold Summary."
),
"intermediate_health_literacy": (
"Intermediate Health Literacy (Medium Readability): the general public at a "
"news-reading level, with standard vocabulary and some common medical terms, "
"and a balanced level of detail led by the Gold Summary."
),
"proficient_health_literacy": (
"Proficient Health Literacy (Low Readability): researchers, clinicians, or "
"highly informed patients, using technical and academic language, high "
"information density, and full clinical nuance and terminology from the "
"Source Text."
),
}
def load_prompts(prompt_dir: str) -> Dict[str, str]:
    """Read the prompt template for every known label from *prompt_dir*.

    Returns a label -> prompt-text mapping; raises FileNotFoundError as soon
    as any expected prompt file is missing.
    """
    loaded: Dict[str, str] = {}
    for label, fname in LABEL_TO_PROMPT_FILE.items():
        path = os.path.join(prompt_dir, fname)
        # Fail fast with a clear message rather than erroring mid-run later.
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Prompt file not found: {path}")
        with open(path, "r", encoding="utf-8") as fh:
            loaded[label] = fh.read()
    return loaded
def build_generation_user_message(
    prompt_template: str,
    full_text: str,
    gold_summary: str,
    source_lang: str = SOURCE_LANG,
) -> str:
    """Fill the {full_text}, {gold_summary} and {source_lang} slots of a template."""
    filled = prompt_template
    # Plain str.replace is used (not str.format) so braces elsewhere in the
    # template are left untouched.
    for slot, value in (
        ("{full_text}", full_text),
        ("{gold_summary}", gold_summary),
        ("{source_lang}", source_lang),
    ):
        filled = filled.replace(slot, value)
    return filled
def extract_summary_from_json_str(raw: str, expected_key: str) -> str:
    """Extract the summary string from a JSON-like model output.

    Parsing strategy, in order:
      1. Strip markdown code fences, isolate the first {...} block, and parse
         it as strict JSON, returning obj[expected_key] (or the sole value if
         the object has exactly one string entry).
      2. Regex-match ``"expected_key": "..."`` with an escape-aware value
         pattern, then JSON-decode the captured value. (The previous pattern
         ``[^"]*`` truncated values at the first escaped quote and left
         escapes such as ``\\n`` undecoded.)
      3. Fall back to returning the raw text, stripped.
    """
    text = raw.strip()

    # Strip markdown-style code fences if present.
    if text.startswith("```"):
        lines = text.splitlines()
        if lines:
            lines = lines[1:]  # drop the opening ``` line
        if lines and lines[-1].strip().startswith("```"):
            lines = lines[:-1]  # drop the closing fence
        text = "\n".join(lines).strip()

    # Try to isolate the first {...} block.
    start = text.find("{")
    end = text.rfind("}")
    candidate = text[start : end + 1] if start != -1 and end > start else text

    # First attempt: strict JSON.
    try:
        obj = json.loads(candidate)
        if isinstance(obj, dict):
            if isinstance(obj.get(expected_key), str):
                return obj[expected_key].strip()
            # If only one key, fall back to that value.
            if len(obj) == 1:
                val = next(iter(obj.values()))
                if isinstance(val, str):
                    return val.strip()
    except Exception:
        pass

    # Second attempt: regex for "<expected_key>": "...". The value pattern
    # permits backslash escapes so an embedded \" does not end the match early.
    key_pattern = re.escape(expected_key)
    m = re.search(rf'"{key_pattern}"\s*:\s*"((?:\\.|[^"\\])*)"', candidate)
    if m:
        escaped = m.group(1)
        try:
            # JSON-decode the captured value to resolve \n, \", \uXXXX, etc.
            return json.loads(f'"{escaped}"').strip()
        except Exception:
            return escaped.strip()

    return raw.strip()
def build_critique_user_message(
    label: str,
    current_summary: str,
) -> str:
    """Compose the user message asking the model to critique readability only."""
    # Unknown labels degrade gracefully: the label itself is shown as the target.
    readability = LABEL_TO_READABILITY.get(label, label)
    parts = [
        "You are an expert medical editor and Health Literacy specialist.\n\n",
        "Read the following patient-facing summary and critique its **readability** ",
        f"for this audience:\n\n{readability}\n\n",
        "Instructions:\n",
        "1. Focus ONLY on clarity, plain language, sentence structure, and suitability ",
        "for the target reader.\n",
        "2. Do NOT add new medical facts that are not already present.\n",
        "3. Identify concrete issues and suggest improvements as bullet points.\n\n",
        "Summary to critique:\n",
        f"{current_summary}\n\n",
        "Now provide a concise critique in bullet points.",
    ]
    return "".join(parts)
def build_revision_user_message(
    label: str,
    current_summary: str,
    critique: str,
) -> str:
    """Compose the user message asking the model to revise the summary per the critique."""
    readability = LABEL_TO_READABILITY.get(label, label)
    # The JSON key is expected to match the label used in the dataset/prompts.
    expected_key = label
    segments = (
        "You are an expert medical editor and Health Literacy specialist.\n\n",
        "Goal: Rewrite the summary so it better matches this readability requirement:\n\n",
        f"{readability}\n\n",
        "Use ONLY the information already present in the original summary. ",
        "Do NOT introduce new clinical facts.\n\n",
        "Original summary:\n",
        f"{current_summary}\n\n",
        "Your previous readability critique:\n",
        f"{critique}\n\n",
        "Now produce an improved version of the summary that addresses the critique.\n",
        "Output **JSON only** with this exact structure:\n",
        '{\n "' + expected_key + '": "..." \n}\n',
    )
    return "".join(segments)
def generate_single(
    llm: LLM,
    sampling_params: SamplingParams,
    tokenizer,
    user_content: str,
) -> str:
    """Run one single-turn chat completion through vLLM and return the stripped text."""
    messages = [{"role": "user", "content": user_content}]
    rendered = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # llm.generate returns one RequestOutput per input prompt; we sent exactly one.
    request_outputs = llm.generate([rendered], sampling_params=sampling_params)
    first_completion = request_outputs[0].outputs[0]
    return first_completion.text.strip()
def self_refine_example(
    llm: LLM,
    tokenizer,
    item: Dict[str, Any],
    prompts: Dict[str, str],
    num_iterations: int,
    gen_sampling: SamplingParams,
    critique_sampling: SamplingParams,
    revise_sampling: SamplingParams,
    source_lang: str = SOURCE_LANG,
) -> Dict[str, Any]:
    """Run generate -> (critique -> revise) x num_iterations for one example.

    Returns a result dict holding the initial summary, the final summary, and
    a per-iteration history; if the item's label has no prompt, an error dict
    with only doc_id/label/error is returned instead.
    """
    label: str = item.get("label")
    doc_id = item.get("doc_id")

    # Guard: without a prompt template for this label we cannot proceed.
    if label not in prompts:
        return {
            "doc_id": doc_id,
            "label": label,
            "error": f"Unknown label: {label}",
        }

    def _ask(params: SamplingParams, content: str) -> str:
        # Thin wrapper so the three call sites below stay compact.
        return generate_single(
            llm=llm,
            sampling_params=params,
            tokenizer=tokenizer,
            user_content=content,
        )

    # Step 1: initial generation from the base prompt.
    initial_message = build_generation_user_message(
        prompt_template=prompts[label],
        full_text=item.get("fulltext", ""),
        gold_summary=item.get("summary", ""),
        source_lang=source_lang,
    )
    raw_initial = _ask(gen_sampling, initial_message)
    current_summary = extract_summary_from_json_str(raw_initial, expected_key=label)

    history: List[Dict[str, Any]] = [
        {
            "iteration": 0,
            "summary": current_summary,
            "raw_model_output": raw_initial,
        }
    ]

    # Iterative critique + revise loop.
    for step in range(1, num_iterations + 1):
        # 2. Critique readability of the current summary.
        raw_critique = _ask(
            critique_sampling,
            build_critique_user_message(label=label, current_summary=current_summary),
        )
        # 3. Revise the summary based on that critique.
        raw_revised = _ask(
            revise_sampling,
            build_revision_user_message(
                label=label,
                current_summary=current_summary,
                critique=raw_critique,
            ),
        )
        revised_summary = extract_summary_from_json_str(raw_revised, expected_key=label)
        history.append(
            {
                "iteration": step,
                "critique": raw_critique,
                "revised_summary": revised_summary,
                "raw_revision_output": raw_revised,
            }
        )
        current_summary = revised_summary

    return {
        "doc_id": doc_id,
        "label": label,
        "readability_requirement": LABEL_TO_READABILITY.get(label, label),
        "gold_summary": item.get("summary", ""),
        "gold_gen_text": item.get("gen_text", ""),
        "initial_summary": history[0]["summary"],
        "final_summary": current_summary,
        "iterations": history,
        "error": None,
    }
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface for this script."""
    parser = argparse.ArgumentParser(
        description=(
            "Run a self-refinement loop (generate → critique → revise) "
            "with Qwen/Qwen3-4B-Instruct-2507 on test_en.json."
        )
    )
    # (flag, add_argument keyword args) pairs, registered in declaration order.
    option_specs = [
        ("--model-id", dict(
            type=str,
            default="Qwen/Qwen3-4B-Instruct-2507",
            help="Hugging Face model id or local path for the Qwen3 instruct model.",
        )),
        ("--prompt-dir", dict(
            type=str,
            default=PROMPT_DIR,
            help="Directory containing prompt files (prompt_low, prompt_intermediate, prompt_proficient).",
        )),
        ("--test-json", dict(
            type=str,
            default=TEST_JSON,
            help="Path to the input test/dataset JSON file.",
        )),
        ("--src-lang", dict(
            type=str,
            default=SOURCE_LANG,
            help="Source language name used in the generation prompt (e.g. English).",
        )),
        ("--num-iterations", dict(
            type=int,
            default=5,
            help="Number of critique+revise iterations to run per example.",
        )),
        ("--max-new-tokens", dict(
            type=int,
            default=512,
            help="Maximum new tokens for summary generation.",
        )),
        ("--critique-max-new-tokens", dict(
            type=int,
            default=256,
            help="Maximum new tokens for critique generation.",
        )),
        ("--revise-max-new-tokens", dict(
            type=int,
            default=512,
            help="Maximum new tokens for revision generation.",
        )),
        ("--temperature", dict(
            type=float,
            default=0.1,
            help="Sampling temperature for generation and revision.",
        )),
        ("--critique-temperature", dict(
            type=float,
            default=0.3,
            help="Sampling temperature for critique (usually lower).",
        )),
        ("--limit", dict(
            type=int,
            default=None,
            help="Optional limit on number of examples from test_en.json (for debugging).",
        )),
        ("--output-file", dict(
            type=str,
            default=None,
            help=(
                "Optional path for the main results JSON file. "
                "If not set, a timestamped name in the results directory is used."
            ),
        )),
    ]
    for flag, kwargs in option_specs:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
def main() -> None:
    """Entry point: load prompts/data/model, self-refine every example, save results.

    Side effects: creates RESULTS_DIR, loads the model onto the pinned GPU,
    and writes two JSON files (detailed results + run-metadata summary).
    """
    args = parse_args()
    os.makedirs(RESULTS_DIR, exist_ok=True)
    print("Loading prompts from", args.prompt_dir)
    prompts = load_prompts(args.prompt_dir)
    print("Loading test data from", args.test_json)
    with open(args.test_json, "r", encoding="utf-8") as f:
        test_list: List[Dict[str, Any]] = json.load(f)
    # Optional truncation of the dataset for quick debugging runs.
    if args.limit is not None:
        test_list = test_list[: args.limit]
        print(f"Limiting to first {len(test_list)} examples.")
    else:
        print(f"Total examples: {len(test_list)}")
    print("Loading tokenizer and model:", args.model_id)
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    llm = LLM(
        model=args.model_id,
        trust_remote_code=True,
    )
    # Three sampling configurations, one per role in the refinement loop.
    gen_sampling = SamplingParams(
        temperature=args.temperature,
        max_tokens=args.max_new_tokens,
        n=1,
    )
    critique_sampling = SamplingParams(
        temperature=args.critique_temperature,
        max_tokens=args.critique_max_new_tokens,
        n=1,
    )
    revise_sampling = SamplingParams(
        temperature=args.temperature,
        max_tokens=args.revise_max_new_tokens,
        n=1,
    )
    results: List[Dict[str, Any]] = []
    total = len(test_list)
    for idx, item in enumerate(test_list):
        print(f"\n=== Processing example {idx + 1}/{total} (doc_id={item.get('doc_id')}, label={item.get('label')}) ===")
        try:
            example_result = self_refine_example(
                llm=llm,
                tokenizer=tokenizer,
                item=item,
                prompts=prompts,
                num_iterations=args.num_iterations,
                gen_sampling=gen_sampling,
                critique_sampling=critique_sampling,
                revise_sampling=revise_sampling,
                source_lang=args.src_lang,
            )
        except Exception as e:
            # Record the failure and keep going so one bad example does not
            # abort the whole run.
            example_result = {
                "doc_id": item.get("doc_id"),
                "label": item.get("label"),
                "error": f"Exception during self-refinement: {e}",
            }
        results.append(example_result)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    if args.output_file:
        out_path = args.output_file
        base, ext = os.path.splitext(out_path)
        # Ensure the user-supplied path ends in .json, then derive the
        # companion summary-file name from it.
        if not ext:
            out_path = base + ".json"
        base = out_path.rsplit(".", 1)[0]
        summary_path = base + "_summary.json"
    else:
        # Default: timestamped file pair in the results directory.
        out_path = os.path.join(RESULTS_DIR, f"self_refine_qwen3_{timestamp}.json")
        summary_path = os.path.join(
            RESULTS_DIR, f"self_refine_qwen3_{timestamp}_summary.json"
        )
    print("\nSaving detailed self-refinement results to", out_path)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    # Run metadata so results can be traced back to their exact configuration.
    summary: Dict[str, Any] = {
        "model_id": args.model_id,
        "prompt_dir": os.path.abspath(args.prompt_dir),
        "test_json": os.path.abspath(args.test_json),
        "src_lang": args.src_lang,
        "num_test_samples": len(test_list),
        "results_file": out_path,
        "timestamp": timestamp,
        "num_iterations": args.num_iterations,
        "max_new_tokens": args.max_new_tokens,
        "critique_max_new_tokens": args.critique_max_new_tokens,
        "revise_max_new_tokens": args.revise_max_new_tokens,
        "temperature": args.temperature,
        "critique_temperature": args.critique_temperature,
    }
    print("Saving summary metadata to", summary_path)
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
    print("\nDone.")
# Standard script entry guard.
if __name__ == "__main__":
    main()