# readCtrl_lambda/code/fine_tune_sft_dpo/qwen3-inference-vllm.py
# (GitHub page residue removed: author mshahidul, "Initial commit of readCtrl
# code without large models", commit 030876e — kept here as a comment so the
# file parses as valid Python.)
"""
Run inference for the finetuned Qwen3 model on test_en.json using vLLM.
This script expects that `qwen3-finetune.py` has already been run and the
merged model was saved to `/home/mshahidul/readctrl/code/fine_tune_sft_dpo/model`.
"""
import os
# Pin the GPU selection before any CUDA-aware library is imported: these
# variables are only honored if they are set prior to CUDA initialization,
# which is why the vllm import must come after them.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "5"  # run everything on physical GPU 5
import argparse
import json
from datetime import datetime
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
# Paths
BASE_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo"
MODEL_DIR = os.path.join(BASE_DIR, "model", "en")  # merged finetuned model (see qwen3-finetune.py)
PROMPT_DIR = os.path.join(BASE_DIR, "prompt_en")  # prompt template files
TEST_JSON = os.path.join(BASE_DIR, "dataset", "en", "test_en.json")  # test split
RESULTS_DIR = os.path.join(BASE_DIR, "results", "en")  # output directory, created on demand
SOURCE_LANG = "English"  # substituted for the {source_lang} template marker
# Map each health-literacy label found in the test data to the name of its
# prompt template file inside PROMPT_DIR.
LABEL_TO_PROMPT_FILE = {
    "low_health_literacy": "prompt_low",
    "intermediate_health_literacy": "prompt_intermediate",
    "proficient_health_literacy": "prompt_proficient",
}
def load_prompts(prompt_dir=None, label_to_file=None):
    """Load prompt templates, keyed by health-literacy label.

    Args:
        prompt_dir: Directory containing the template files. Defaults to
            the module-level ``PROMPT_DIR``.
        label_to_file: Mapping of label -> template filename. Defaults to
            the module-level ``LABEL_TO_PROMPT_FILE``.

    Returns:
        dict mapping each label to the full text of its template file.

    Raises:
        FileNotFoundError: if any expected template file is missing.
    """
    if prompt_dir is None:
        prompt_dir = PROMPT_DIR
    if label_to_file is None:
        label_to_file = LABEL_TO_PROMPT_FILE
    prompts = {}
    for label, filename in label_to_file.items():
        path = os.path.join(prompt_dir, filename)
        # Fail fast: a missing template would otherwise silently drop every
        # sample carrying that label from inference.
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Prompt file not found: {path}")
        with open(path, "r", encoding="utf-8") as f:
            prompts[label] = f.read()
    return prompts
def build_user_message(prompt_template, full_text, gold_summary, source_lang=SOURCE_LANG):
    """Substitute the template markers with the given sample values.

    The template may contain the literal markers ``{full_text}``,
    ``{gold_summary}`` and ``{source_lang}``; each occurrence is replaced
    with the corresponding argument, in that order.
    """
    substitutions = {
        "{full_text}": full_text,
        "{gold_summary}": gold_summary,
        "{source_lang}": source_lang,
    }
    message = prompt_template
    for marker, value in substitutions.items():
        message = message.replace(marker, value)
    return message
def parse_args():
    """Parse command-line options for the inference run."""
    parser = argparse.ArgumentParser(
        description="Run vLLM inference for health-literacy Qwen3 model on test_en.json."
    )
    # (flag, value type, default, help text) — registered uniformly below.
    option_specs = [
        (
            "--model-dir",
            str,
            MODEL_DIR,
            "Path to the merged finetuned model directory.",
        ),
        (
            "--max-new-tokens",
            int,
            1024,
            "Maximum number of new tokens to generate.",
        ),
        (
            "--temperature",
            float,
            0.0,
            "Sampling temperature for generation.",
        ),
        (
            "--batch-size",
            int,
            32,
            "Number of samples per vLLM generation call.",
        ),
    ]
    for flag, value_type, default, help_text in option_specs:
        parser.add_argument(flag, type=value_type, default=default, help=help_text)
    return parser.parse_args()
def main():
    """Entry point: run batched vLLM inference on test_en.json and save results.

    Steps: load the label-specific prompt templates and the test data,
    format each sample through the tokenizer's chat template, generate in
    batches with vLLM, then write a per-sample results JSON and a
    run-summary JSON into RESULTS_DIR.
    """
    args = parse_args()
    model_dir = args.model_dir
    os.makedirs(RESULTS_DIR, exist_ok=True)
    print("Loading prompts from", PROMPT_DIR)
    prompts = load_prompts()
    print("Loading test data from", TEST_JSON)
    with open(TEST_JSON, "r", encoding="utf-8") as f:
        test_list = json.load(f)
    print("Loading tokenizer and model from", model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    llm = LLM(
        model=model_dir,
        trust_remote_code=True,
    )
    # n=1: exactly one completion per prompt; with the parse_args default of
    # temperature 0.0 this is greedy decoding.
    sampling_params = SamplingParams(
        temperature=args.temperature,
        max_tokens=args.max_new_tokens,
        n=1,
    )
    # Build prompts in the same way as training/inference before, via chat template.
    # batched_prompts and meta are kept index-aligned with test_list; a sample
    # whose label has no template contributes a None placeholder so that the
    # alignment between the two lists is never broken.
    batched_prompts = []
    meta = []
    for idx, item in enumerate(test_list):
        label = item.get("label")
        doc_id = item.get("doc_id", idx)  # fall back to list position as the id
        fulltext = item.get("fulltext", "")
        summary = item.get("summary", "")
        gold_gen_text = item.get("gen_text", "")
        if label not in prompts:
            # Unknown label: record the error and keep the placeholder.
            meta.append(
                {
                    "doc_id": doc_id,
                    "label": label,
                    "gold_gen_text": gold_gen_text,
                    "error": f"Unknown label: {label}",
                }
            )
            batched_prompts.append(None)
            continue
        user_prompt = build_user_message(prompts[label], fulltext, summary)
        chat = [{"role": "user", "content": user_prompt}]
        formatted = tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )
        batched_prompts.append(formatted)
        meta.append(
            {
                "doc_id": doc_id,
                "label": label,
                "gold_gen_text": gold_gen_text,
                "error": None,
            }
        )
    generated_texts = {}  # original test_list index -> generated text
    # Filter out None prompts (unknown labels) for generation
    valid_indices = [i for i, p in enumerate(batched_prompts) if p is not None]
    valid_prompts = [batched_prompts[i] for i in valid_indices]
    total_valid = len(valid_prompts)
    batch_size = max(1, args.batch_size)  # guard against a non-positive --batch-size
    print(
        f"Running vLLM generation on {total_valid} samples "
        f"in batches of {batch_size}..."
    )
    # Run batched generation to avoid overloading memory or GPU
    num_batches = (total_valid + batch_size - 1) // batch_size  # ceiling division
    for batch_idx in range(num_batches):
        start = batch_idx * batch_size
        end = min(start + batch_size, total_valid)
        batch_prompts = valid_prompts[start:end]
        batch_indices = valid_indices[start:end]
        print(
            f"Generating batch {batch_idx + 1}/{num_batches} "
            f"with {len(batch_prompts)} samples..."
        )
        outputs = llm.generate(batch_prompts, sampling_params=sampling_params)
        # Map generation results for this batch back to global indices
        for idx_in_batch, output in enumerate(outputs):
            original_idx = batch_indices[idx_in_batch]
            # n=1 in sampling_params, so each output holds one completion.
            text = output.outputs[0].text.strip()
            generated_texts[original_idx] = text
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Assemble one result record per test sample, preserving input order;
    # error records carry "error", successful ones "predicted_gen_text".
    results = []
    for idx, info in enumerate(meta):
        if info["error"] is not None:
            results.append(
                {
                    "doc_id": info["doc_id"],
                    "label": info["label"],
                    "gold_gen_text": info["gold_gen_text"],
                    "error": info["error"],
                }
            )
        else:
            pred_text = generated_texts.get(idx, "")
            results.append(
                {
                    "doc_id": info["doc_id"],
                    "label": info["label"],
                    "gold_gen_text": info["gold_gen_text"],
                    "predicted_gen_text": pred_text,
                }
            )
    out_path = os.path.join(RESULTS_DIR, f"test_inference_vllm_{timestamp}.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    # Companion summary file recording this run's configuration.
    summary_path = os.path.join(RESULTS_DIR, f"inference_summary_vllm_{timestamp}.json")
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(
            {
                "model_dir": model_dir,
                "test_json": TEST_JSON,
                "prompt_dir": PROMPT_DIR,
                "num_test_samples": len(test_list),
                "results_file": out_path,
                "timestamp": timestamp,
                "max_new_tokens": args.max_new_tokens,
                "temperature": args.temperature,
            },
            f,
            ensure_ascii=False,
            indent=2,
        )
    print(f"Results saved to {out_path}")
    print(f"Summary saved to {summary_path}")
if __name__ == "__main__":  # allow importing this module without running inference
    main()