| """ |
| Run inference for the finetuned Qwen3 model on test_en.json using vLLM. |
| |
| This script expects that `qwen3-finetune.py` has already been run and the |
| merged model was saved to `/home/mshahidul/readctrl/code/fine_tune_sft_dpo/model`. |
| """ |
|
|
import os

# Pin GPU selection BEFORE any CUDA-aware import (vllm/torch below):
# these environment variables are consulted when CUDA initializes, so
# setting them later would have no effect.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "5"
|
|
| import argparse |
| import json |
| from datetime import datetime |
|
|
| from vllm import LLM, SamplingParams |
| from transformers import AutoTokenizer |
|
|
|
|
| |
# Project layout for the English (en) finetuning run.
BASE_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo"
# Merged finetuned model directory (produced by qwen3-finetune.py).
MODEL_DIR = os.path.join(BASE_DIR, "model", "en")
# One prompt template file per health-literacy level lives here.
PROMPT_DIR = os.path.join(BASE_DIR, "prompt_en")
TEST_JSON = os.path.join(BASE_DIR, "dataset", "en", "test_en.json")
RESULTS_DIR = os.path.join(BASE_DIR, "results", "en")

# Value substituted into the {source_lang} placeholder of each prompt.
SOURCE_LANG = "English"
# Maps a test item's "label" field to its prompt template filename in PROMPT_DIR.
LABEL_TO_PROMPT_FILE = {
    "low_health_literacy": "prompt_low",
    "intermediate_health_literacy": "prompt_intermediate",
    "proficient_health_literacy": "prompt_proficient",
}
|
|
|
|
def load_prompts(prompt_dir=None, label_to_file=None):
    """Load prompt templates, keyed by health-literacy label.

    Generalized so the directory and label->filename mapping can be
    injected (e.g. for other languages or tests); calling with no
    arguments preserves the original behavior of reading from the
    module-level ``PROMPT_DIR`` / ``LABEL_TO_PROMPT_FILE``.

    Args:
        prompt_dir: Directory containing the prompt files.
            Defaults to the module-level ``PROMPT_DIR``.
        label_to_file: Mapping of label -> prompt filename.
            Defaults to the module-level ``LABEL_TO_PROMPT_FILE``.

    Returns:
        dict mapping each label to the full text of its prompt file.

    Raises:
        FileNotFoundError: if any expected prompt file is missing.
    """
    prompt_dir = PROMPT_DIR if prompt_dir is None else prompt_dir
    label_to_file = LABEL_TO_PROMPT_FILE if label_to_file is None else label_to_file

    prompts = {}
    for label, filename in label_to_file.items():
        path = os.path.join(prompt_dir, filename)
        # Fail fast with the exact missing path rather than a partial dict.
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Prompt file not found: {path}")
        with open(path, "r", encoding="utf-8") as f:
            prompts[label] = f.read()
    return prompts
|
|
|
|
def build_user_message(prompt_template, full_text, gold_summary, source_lang=SOURCE_LANG):
    """Return *prompt_template* with its placeholders substituted.

    The placeholders ``{full_text}``, ``{gold_summary}`` and
    ``{source_lang}`` are replaced, in that order, by the corresponding
    arguments. Plain ``str.replace`` is used (not ``str.format``) so any
    other braces in the template survive untouched.
    """
    substitutions = (
        ("{full_text}", full_text),
        ("{gold_summary}", gold_summary),
        ("{source_lang}", source_lang),
    )
    message = prompt_template
    for placeholder, value in substitutions:
        message = message.replace(placeholder, value)
    return message
|
|
|
|
def parse_args():
    """Build and evaluate the command-line interface for this script.

    Options: --model-dir, --max-new-tokens, --temperature, --batch-size.
    """
    parser = argparse.ArgumentParser(
        description="Run vLLM inference for health-literacy Qwen3 model on test_en.json."
    )
    # (flag, type, default, help) — registered in a single data-driven pass.
    option_specs = (
        (
            "--model-dir",
            str,
            MODEL_DIR,
            "Path to the merged finetuned model directory.",
        ),
        (
            "--max-new-tokens",
            int,
            1024,
            "Maximum number of new tokens to generate.",
        ),
        (
            "--temperature",
            float,
            0.0,
            "Sampling temperature for generation.",
        ),
        (
            "--batch-size",
            int,
            32,
            "Number of samples per vLLM generation call.",
        ),
    )
    for flag, flag_type, default, help_text in option_specs:
        parser.add_argument(flag, type=flag_type, default=default, help=help_text)
    return parser.parse_args()
|
|
|
|
def main():
    """Run batched vLLM inference over test_en.json and write results to disk.

    Pipeline: load prompt templates and test items, build one
    chat-formatted prompt per item, generate in fixed-size batches with
    greedy/sampled decoding, then write a per-sample results JSON and a
    small run-summary JSON (both timestamped) into RESULTS_DIR.
    """
    args = parse_args()
    model_dir = args.model_dir

    os.makedirs(RESULTS_DIR, exist_ok=True)

    print("Loading prompts from", PROMPT_DIR)
    prompts = load_prompts()

    print("Loading test data from", TEST_JSON)
    with open(TEST_JSON, "r", encoding="utf-8") as f:
        test_list = json.load(f)

    print("Loading tokenizer and model from", model_dir)
    # The HF tokenizer is only used for chat-template formatting below;
    # vLLM loads its own tokenizer/model internally.
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    llm = LLM(
        model=model_dir,
        trust_remote_code=True,
    )

    # Default temperature is 0.0 (see parse_args), i.e. greedy decoding.
    sampling_params = SamplingParams(
        temperature=args.temperature,
        max_tokens=args.max_new_tokens,
        n=1,
    )

    # Build one formatted prompt per test item. Invariant: batched_prompts
    # and meta stay index-aligned with test_list — items with an unknown
    # label get a None prompt placeholder plus a meta entry carrying the
    # error, so later index lookups remain valid.
    batched_prompts = []
    meta = []
    for idx, item in enumerate(test_list):
        label = item.get("label")
        # Fall back to the positional index when the item has no doc_id.
        doc_id = item.get("doc_id", idx)
        fulltext = item.get("fulltext", "")
        summary = item.get("summary", "")
        # "gen_text" is the gold reference generation, kept for evaluation.
        gold_gen_text = item.get("gen_text", "")

        if label not in prompts:
            # Unknown label: record the error and a None placeholder so
            # positions keep lining up; this item is never generated.
            meta.append(
                {
                    "doc_id": doc_id,
                    "label": label,
                    "gold_gen_text": gold_gen_text,
                    "error": f"Unknown label: {label}",
                }
            )
            batched_prompts.append(None)
            continue

        user_prompt = build_user_message(prompts[label], fulltext, summary)
        chat = [{"role": "user", "content": user_prompt}]
        # Render the single-turn chat into the model's prompt format,
        # appending the assistant-generation marker.
        formatted = tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )

        batched_prompts.append(formatted)
        meta.append(
            {
                "doc_id": doc_id,
                "label": label,
                "gold_gen_text": gold_gen_text,
                "error": None,
            }
        )

    # Maps original item index -> generated text (valid prompts only).
    generated_texts = {}

    # Compact the prompt list, remembering each prompt's original index.
    valid_indices = [i for i, p in enumerate(batched_prompts) if p is not None]
    valid_prompts = [batched_prompts[i] for i in valid_indices]

    total_valid = len(valid_prompts)
    # Guard against a non-positive --batch-size.
    batch_size = max(1, args.batch_size)
    print(
        f"Running vLLM generation on {total_valid} samples "
        f"in batches of {batch_size}..."
    )

    # Ceiling division: number of generation calls needed.
    num_batches = (total_valid + batch_size - 1) // batch_size
    for batch_idx in range(num_batches):
        start = batch_idx * batch_size
        end = min(start + batch_size, total_valid)
        batch_prompts = valid_prompts[start:end]
        batch_indices = valid_indices[start:end]

        print(
            f"Generating batch {batch_idx + 1}/{num_batches} "
            f"with {len(batch_prompts)} samples..."
        )
        outputs = llm.generate(batch_prompts, sampling_params=sampling_params)

        # vLLM returns one output object per input prompt, in order;
        # map each back to its original test_list index (n=1, so take
        # the single completion).
        for idx_in_batch, output in enumerate(outputs):
            original_idx = batch_indices[idx_in_batch]
            text = output.outputs[0].text.strip()
            generated_texts[original_idx] = text

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results = []

    # Assemble one result record per test item. NOTE: error records carry
    # an "error" key instead of "predicted_gen_text" — downstream readers
    # must handle both shapes.
    for idx, info in enumerate(meta):
        if info["error"] is not None:
            results.append(
                {
                    "doc_id": info["doc_id"],
                    "label": info["label"],
                    "gold_gen_text": info["gold_gen_text"],
                    "error": info["error"],
                }
            )
        else:
            pred_text = generated_texts.get(idx, "")
            results.append(
                {
                    "doc_id": info["doc_id"],
                    "label": info["label"],
                    "gold_gen_text": info["gold_gen_text"],
                    "predicted_gen_text": pred_text,
                }
            )

    out_path = os.path.join(RESULTS_DIR, f"test_inference_vllm_{timestamp}.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    # Companion summary file recording the run configuration and where
    # the per-sample results were written.
    summary_path = os.path.join(RESULTS_DIR, f"inference_summary_vllm_{timestamp}.json")
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(
            {
                "model_dir": model_dir,
                "test_json": TEST_JSON,
                "prompt_dir": PROMPT_DIR,
                "num_test_samples": len(test_list),
                "results_file": out_path,
                "timestamp": timestamp,
                "max_new_tokens": args.max_new_tokens,
                "temperature": args.temperature,
            },
            f,
            ensure_ascii=False,
            indent=2,
        )

    print(f"Results saved to {out_path}")
    print(f"Summary saved to {summary_path}")
|
|
|
|
# Script entry point: only run inference when executed directly.
if __name__ == "__main__":
    main()
|
|
|
|