| import argparse |
| import json |
| import os |
| import re |
| import time |
| from concurrent.futures import ThreadPoolExecutor, as_completed |
| from datetime import datetime |
| from typing import Any, Dict, List, Optional |
|
|
| import pandas as pd |
| import requests |
| from tqdm import tqdm |
| from transformers import AutoTokenizer |
|
|
|
|
# Hugging Face model id/path used locally to load the tokenizer and chat template.
DEFAULT_MODEL_PATH = "Qwen/Qwen3-4B-Instruct-2507"
# Default input dataset: a JSON array of rows with "label", "fulltext", "summary".
DEFAULT_DATASET_PATH = (
    "/home/mshahidul/readctrl/code/readctrl_rl_inference/verified_combined_0-80_clean200.json"
)
# Directory where the JSONL predictions and run-metadata JSON are written.
DEFAULT_OUTPUT_DIR = "/home/mshahidul/readctrl/code/readctrl_rl_inference/vllm_model_result"
# OpenAI-compatible vLLM endpoint and the model name it serves.
DEFAULT_BASE_URL = "http://127.0.0.1:8021/v1"
DEFAULT_SERVED_MODEL_NAME = "inference"
# Per-literacy-level prompt template files (plain text with {placeholders}).
DEFAULT_PROMPT_LOW_PATH = (
    "/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/prompt_low"
)
DEFAULT_PROMPT_INTERMEDIATE_PATH = (
    "/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/prompt_intermediate"
)
DEFAULT_PROMPT_PROFICIENT_PATH = (
    "/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/prompt_proficient"
)
# Gold labels accepted from the dataset; rows with any other label are skipped.
VALID_LABELS = {
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
}
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build and parse the command-line options for this inference run."""
    parser = argparse.ArgumentParser(description="Run batched inference via vLLM OpenAI-compatible server.")
    add = parser.add_argument
    add("--model_path", type=str, default=DEFAULT_MODEL_PATH, help="Local path for tokenizer/chat template.")
    add("--dataset_path", type=str, default=DEFAULT_DATASET_PATH)
    add(
        "--input_name",
        type=str,
        default=None,
        help="Optional short name for the input file; used in output filenames. If not provided, derived from the basename of --dataset_path.",
    )
    add(
        "--output_name",
        type=str,
        default=None,
        help="Base name (without extension) for output files. If not provided, uses vllm_inference_{model_tag}_{input_name_or_dataset}_{timestamp}.",
    )
    # Prompt template locations, one per literacy level.
    add("--prompt-low-path", type=str, default=DEFAULT_PROMPT_LOW_PATH)
    add("--prompt-intermediate-path", type=str, default=DEFAULT_PROMPT_INTERMEDIATE_PATH)
    add("--prompt-proficient-path", type=str, default=DEFAULT_PROMPT_PROFICIENT_PATH)
    add("--output_dir", type=str, default=DEFAULT_OUTPUT_DIR)
    # Server connection settings.
    add("--base_url", type=str, default=DEFAULT_BASE_URL, help="vLLM OpenAI base URL, e.g. http://127.0.0.1:8000/v1")
    add("--served_model_name", type=str, default=DEFAULT_SERVED_MODEL_NAME, help="Model name exposed by vLLM server.")
    # Batch/sampling controls.
    add("--batch_size", type=int, default=64)
    add("--max_samples", type=int, default=-1, help="Use -1 for full dataset.")
    add("--max_tokens", type=int, default=1024)
    add("--temperature", type=float, default=0.1)
    add("--top_p", type=float, default=0.8)
    add("--api_key", type=str, default="EMPTY")
    add("--timeout_sec", type=int, default=300)
    add("--num_workers", type=int, default=4, help="Concurrent request threads to keep server pipeline full.")
    return parser.parse_args()
|
|
|
|
def load_prompt_templates(args: argparse.Namespace) -> Dict[str, str]:
    """Read the per-literacy-level prompt templates from disk.

    Returns a mapping from gold label to raw template text.
    Raises FileNotFoundError when any configured template path is missing.
    """
    label_to_path = (
        ("low_health_literacy", args.prompt_low_path),
        ("intermediate_health_literacy", args.prompt_intermediate_path),
        ("proficient_health_literacy", args.prompt_proficient_path),
    )
    templates: Dict[str, str] = {}
    for label, template_path in label_to_path:
        if not os.path.exists(template_path):
            raise FileNotFoundError(f"Prompt file not found: {template_path}")
        with open(template_path, "r", encoding="utf-8") as fh:
            templates[label] = fh.read()
    return templates
|
|
|
|
def load_verified_rows(path: str) -> List[Dict[str, Any]]:
    """Load a JSON file whose top level must be an array; keep dict rows only.

    Raises FileNotFoundError if *path* does not exist and ValueError if the
    parsed document is not a list.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Input file not found: {path}")
    with open(path, "r", encoding="utf-8") as fh:
        payload = json.load(fh)
    if not isinstance(payload, list):
        raise ValueError(f"Expected top-level JSON array in {path}")
    # Silently drop non-dict entries (e.g. stray strings/numbers).
    return [entry for entry in payload if isinstance(entry, dict)]
|
|
|
|
def infer_source_lang(fulltext: str) -> str:
    """Best-effort source-language detection.

    Returns "English" if *fulltext* contains at least one ASCII Latin letter,
    otherwise "Unknown" (including for empty input).
    """
    # Fix: the previous check used `"a" <= ch.lower() <= "z"`, which can
    # misfire because str.lower() may return a multi-character string
    # (e.g. "İ".lower() == "i̇"), and lexicographic comparison on that result
    # yields false positives. Comparing both cases on the raw character
    # keeps the test strictly to ASCII Latin letters.
    if fulltext and any("a" <= ch <= "z" or "A" <= ch <= "Z" for ch in fulltext):
        return "English"
    return "Unknown"
|
|
|
|
def split_into_subclaims(text: str, min_chars: int = 15) -> List[str]:
    """Approximate subclaims by splitting *text* at sentence boundaries.

    Splits after `.`, `!`, or `?` followed by whitespace; segments shorter
    than *min_chars* characters (after stripping) are discarded.
    """
    stripped = text.strip() if text else ""
    if not stripped:
        return []
    segments = (seg.strip() for seg in re.split(r"(?<=[.!?])\s+", stripped))
    return [seg for seg in segments if len(seg) >= min_chars]
|
|
|
|
def build_prompt(template: str, fulltext: str, summary: str, source_lang: str) -> str:
    """Fill the literal placeholders in *template* with the sample's fields."""
    filled = template.replace("{source_lang}", source_lang)
    filled = filled.replace("{gold_summary}", summary)
    return filled.replace("{full_text}", fulltext)
|
|
|
|
| def _clean_json_block(text: str) -> str: |
| cleaned = text.strip() |
| if "```json" in cleaned: |
| cleaned = cleaned.split("```json", 1)[1].split("```", 1)[0].strip() |
| elif "```" in cleaned: |
| cleaned = cleaned.split("```", 1)[1].split("```", 1)[0].strip() |
| return cleaned |
|
|
|
|
def extract_generated_text(raw_response: str, expected_label: str) -> str:
    """Pull the text keyed by *expected_label* out of a JSON model response.

    Falls back to the stripped raw response whenever the response is not
    valid JSON (after removing any Markdown code fence), is not a dict, or
    lacks a non-empty string under the expected key.
    """
    candidate = raw_response.strip()
    # Fence stripping, mirroring _clean_json_block.
    if "```json" in candidate:
        candidate = candidate.split("```json", 1)[1].split("```", 1)[0].strip()
    elif "```" in candidate:
        candidate = candidate.split("```", 1)[1].split("```", 1)[0].strip()

    try:
        payload = json.loads(candidate)
    except json.JSONDecodeError:
        return raw_response.strip()

    if isinstance(payload, dict):
        extracted = payload.get(expected_label)
        if isinstance(extracted, str) and extracted.strip():
            return extracted.strip()
    return raw_response.strip()
|
|
|
|
| def _normalize_messages(prompt_obj: Any) -> List[Dict[str, str]]: |
| if hasattr(prompt_obj, "tolist"): |
| prompt_obj = prompt_obj.tolist() |
|
|
| if isinstance(prompt_obj, dict): |
| if "role" in prompt_obj and "content" in prompt_obj: |
| return [{"role": str(prompt_obj["role"]), "content": str(prompt_obj["content"])}] |
| return [{"role": "user", "content": json.dumps(prompt_obj, ensure_ascii=False)}] |
|
|
| if isinstance(prompt_obj, list): |
| messages = [] |
| for item in prompt_obj: |
| if isinstance(item, dict) and "role" in item and "content" in item: |
| messages.append({"role": str(item["role"]), "content": str(item["content"])}) |
| else: |
| messages.append({"role": "user", "content": str(item)}) |
| if messages: |
| return messages |
|
|
| return [{"role": "user", "content": str(prompt_obj)}] |
|
|
|
|
def build_prompt_text(tokenizer: AutoTokenizer, prompt_obj: Any) -> str:
    """Render *prompt_obj* into a completion-ready prompt string.

    Uses the tokenizer's chat template when one is configured; otherwise
    joins the message contents and appends a plain "Assistant:" cue.
    """
    messages = _normalize_messages(prompt_obj)
    if not tokenizer.chat_template:
        joined = "\n".join(msg["content"] for msg in messages)
        return f"{joined}\n\nAssistant:"
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
|
|
|
def sanitize_model_tag(model_path: str, max_len: int = 80) -> str:
    """Build a filesystem-friendly lowercase tag from a model path.

    Non-alphanumeric runs become single hyphens; the result is trimmed to
    *max_len* characters (without a trailing hyphen). Returns
    "unknown-model" when nothing alphanumeric remains.
    """
    tag = re.sub(r"[^A-Za-z0-9]+", "-", model_path).strip("-").lower()
    if not tag:
        return "unknown-model"
    return tag[:max_len].rstrip("-") if len(tag) > max_len else tag
|
|
|
|
def check_server(base_url: str, headers: Dict[str, str], timeout_sec: int) -> Optional[List[Dict[str, Any]]]:
    """Query the server's /models endpoint and return its model list.

    Raises requests.HTTPError (via raise_for_status) on a non-2xx response.
    Returns an empty list when the payload has no "data" key.
    """
    endpoint = "{}/models".format(base_url.rstrip("/"))
    response = requests.get(endpoint, headers=headers, timeout=timeout_sec)
    response.raise_for_status()
    return response.json().get("data", [])
|
|
|
|
def batched_completion_request(
    base_url: str,
    headers: Dict[str, str],
    model_name: str,
    prompts: List[str],
    max_tokens: int,
    temperature: float,
    top_p: float,
    timeout_sec: int,
) -> List[str]:
    """POST a batch of prompts to the server's /completions endpoint.

    Returns one prediction string per prompt, in prompt order. Choices are
    first matched to prompts by their reported "index" field; any prompt
    still left without a prediction falls back to the choice at the same
    positional offset (or "" if there is none).
    """
    request_body = {
        "model": model_name,
        "prompt": prompts,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    endpoint = "{}/completions".format(base_url.rstrip("/"))
    response = requests.post(endpoint, headers=headers, json=request_body, timeout=timeout_sec)
    response.raise_for_status()
    choices = response.json().get("choices", [])

    predictions = [""] * len(prompts)
    # Primary mapping: trust each choice's "index"; first non-empty match wins.
    for choice in choices:
        position = choice.get("index", None)
        completion = str(choice.get("text", "")).strip()
        if isinstance(position, int) and 0 <= position < len(predictions) and not predictions[position]:
            predictions[position] = completion

    # Fallback mapping: fill remaining empty slots by choice order.
    if any(not pred for pred in predictions):
        ordered_texts = [str(c.get("text", "")).strip() for c in choices]
        for slot, current in enumerate(predictions):
            if not current:
                predictions[slot] = ordered_texts[slot] if slot < len(ordered_texts) else ""

    return predictions
|
|
|
|
def main() -> None:
    """Run the full batch-inference pipeline.

    Steps: parse CLI args, probe the vLLM server, load the tokenizer and
    prompt templates, filter and shape the dataset, fan prompt batches out
    over a thread pool, stream predictions to a JSONL file in dataset order,
    and write run metadata alongside the output.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    run_ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_tag = sanitize_model_tag(args.model_path)

    # Short tag for output filenames: explicit --input_name wins, otherwise
    # the dataset file's basename (without extension) is used.
    input_tag_raw = (
        args.input_name
        if args.input_name
        else os.path.splitext(os.path.basename(args.dataset_path))[0]
    )
    # Reuse the model-tag sanitizer so the input tag is filesystem-safe too.
    input_tag = sanitize_model_tag(input_tag_raw)
    default_base = f"vllm_inference_{model_tag}_{input_tag}_{run_ts}"
    base_name = args.output_name if args.output_name else default_base
    output_jsonl = os.path.join(args.output_dir, f"{base_name}.jsonl")
    meta_path = os.path.join(args.output_dir, f"{base_name}_meta.json")

    headers = {
        "Authorization": f"Bearer {args.api_key}",
        "Content-Type": "application/json",
    }

    # Probe /models up front; an unknown served-model name is only a warning
    # because the server may still accept requests routed by that name.
    print(f"[INFO] Checking vLLM server: {args.base_url}")
    models = check_server(args.base_url, headers=headers, timeout_sec=args.timeout_sec)
    available_model_ids = [m.get("id", "") for m in models or []]
    print(f"[INFO] Server models: {available_model_ids}")
    if args.served_model_name not in available_model_ids:
        print(
            f"[WARN] Served model '{args.served_model_name}' not found in /models. "
            "Will still try requests with provided name."
        )

    print(f"[INFO] Loading tokenizer from: {args.model_path}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

    print(f"[INFO] Reading dataset: {args.dataset_path}")
    templates = load_prompt_templates(args)
    rows = load_verified_rows(args.dataset_path)
    # Keep only rows with a recognized label and non-empty fulltext/summary;
    # each surviving row gets its literacy-level prompt pre-built here.
    parsed_items: List[Dict[str, Any]] = []
    for idx, row in enumerate(rows):
        gold_label = str(row.get("label", "")).strip()
        fulltext = str(row.get("fulltext", "")).strip()
        summary = str(row.get("summary", "")).strip()
        if gold_label not in VALID_LABELS:
            continue
        if not fulltext or not summary:
            continue
        source_lang = infer_source_lang(fulltext)
        subclaims = split_into_subclaims(summary)
        prompt = build_prompt(
            template=templates[gold_label],
            fulltext=fulltext,
            summary=summary,
            source_lang=source_lang,
        )
        parsed_items.append(
            {
                "row_index": idx,
                "doc_id": row.get("doc_id"),
                "gold_label": gold_label,
                "source_lang": source_lang,
                "summary_text": summary,
                "input_text": fulltext,
                "subclaims": subclaims,
                "prompt": prompt,
            }
        )

    df = pd.DataFrame(parsed_items)
    if args.max_samples > 0:
        df = df.head(args.max_samples)
    print(f"[INFO] Rows to process: {len(df)}")
    if df.empty:
        raise RuntimeError("No valid rows found in input file.")

    # Batch start offsets into the dataframe.
    batch_ranges = list(range(0, len(df), args.batch_size))
    total_batches = len(batch_ranges)
    # Never spawn more workers than there are batches.
    num_workers = min(args.num_workers, total_batches)
    print(f"[INFO] {total_batches} batches × {args.batch_size} prompts, {num_workers} concurrent workers")

    t0 = time.time()

    def _process_batch(start: int) -> List[Dict[str, Any]]:
        # Send one dataframe slice to the server and shape the responses
        # into output records (one per input row, in slice order).
        batch_df = df.iloc[start : start + args.batch_size]
        prompts = [build_prompt_text(tokenizer, row.get("prompt", "")) for _, row in batch_df.iterrows()]
        preds = batched_completion_request(
            base_url=args.base_url,
            headers=headers,
            model_name=args.served_model_name,
            prompts=prompts,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            timeout_sec=args.timeout_sec,
        )
        records = []
        for (row_idx, row), pred in zip(batch_df.iterrows(), preds):
            gold_label = str(row.get("gold_label", ""))
            records.append(
                {
                    "row_index": int(row.get("row_index", row_idx)),
                    "doc_id": row.get("doc_id"),
                    "gold_label": gold_label,
                    "source_lang": row.get("source_lang"),
                    "summary_text": row.get("summary_text", ""),
                    "input_text": row.get("input_text", ""),
                    "subclaims": row.get("subclaims", []),
                    "prediction": pred,
                    # Parse the label-keyed JSON answer when a gold label is
                    # known; otherwise keep the raw (stripped) prediction.
                    "generated_text": extract_generated_text(pred, gold_label)
                    if gold_label
                    else pred.strip(),
                }
            )
        return records

    # Batches finish out of order; buffer them by batch index and flush to
    # disk strictly in order so the JSONL preserves the dataset's ordering.
    pending_results: Dict[int, List[Dict[str, Any]]] = {}
    next_write_idx = 0
    outputs: List[Dict[str, Any]] = []

    with open(output_jsonl, "w", encoding="utf-8") as f_out:
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            future_to_idx = {
                executor.submit(_process_batch, batch_ranges[i]): i
                for i in range(total_batches)
            }
            pbar = tqdm(total=total_batches, desc="Batches")
            for future in as_completed(future_to_idx):
                batch_idx = future_to_idx[future]
                records = future.result()
                pending_results[batch_idx] = records
                pbar.update(1)

                # Drain every buffered batch that is now contiguous with
                # what has already been written.
                while next_write_idx in pending_results:
                    for rec in pending_results.pop(next_write_idx):
                        outputs.append(rec)
                        f_out.write(json.dumps(rec, ensure_ascii=False) + "\n")
                    next_write_idx += 1
            pbar.close()

    elapsed = time.time() - t0
    print(f"[INFO] Inference done: {len(outputs)} samples in {elapsed:.1f}s ({len(outputs)/elapsed:.1f} samples/s)")

    # Persist the run settings alongside the predictions for reproducibility.
    with open(meta_path, "w", encoding="utf-8") as f_meta:
        json.dump(
            {
                "model_path_for_tokenizer": args.model_path,
                "dataset_path": args.dataset_path,
                "input_name": input_tag,
                "output_name": base_name,
                "base_url": args.base_url,
                "served_model_name": args.served_model_name,
                "batch_size": args.batch_size,
                "num_samples": len(outputs),
                "output_jsonl": output_jsonl,
            },
            f_meta,
            ensure_ascii=False,
            indent=2,
        )

    print("[DONE] vLLM batch inference complete.")
    print(f"[DONE] JSONL: {output_jsonl}")
    print(f"[DONE] Meta: {meta_path}")
|
|
|
|
# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|