"""Batched inference through a vLLM OpenAI-compatible server.

Reads a verified JSON dataset, builds a prompt per row from the label-specific
template (low / intermediate / proficient health literacy), sends batched
/completions requests, and writes predictions to a JSONL file plus a metadata
JSON file.
"""

import argparse
import json
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from typing import Any, Dict, List, Optional

import pandas as pd
import requests
from tqdm import tqdm
from transformers import AutoTokenizer

DEFAULT_MODEL_PATH = "Qwen/Qwen3-4B-Instruct-2507"
DEFAULT_DATASET_PATH = (
    "/home/mshahidul/readctrl/code/readctrl_rl_inference/verified_combined_0-80_clean200.json"
)
DEFAULT_OUTPUT_DIR = "/home/mshahidul/readctrl/code/readctrl_rl_inference/vllm_model_result"
DEFAULT_BASE_URL = "http://127.0.0.1:8021/v1"
DEFAULT_SERVED_MODEL_NAME = "inference"
DEFAULT_PROMPT_LOW_PATH = (
    "/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/prompt_low"
)
DEFAULT_PROMPT_INTERMEDIATE_PATH = (
    "/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/prompt_intermediate"
)
DEFAULT_PROMPT_PROFICIENT_PATH = (
    "/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/prompt_proficient"
)

VALID_LABELS = {
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
}


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Run batched inference via vLLM OpenAI-compatible server.")
    parser.add_argument("--model_path", type=str, default=DEFAULT_MODEL_PATH, help="Local path for tokenizer/chat template.")
    parser.add_argument("--dataset_path", type=str, default=DEFAULT_DATASET_PATH)
    parser.add_argument(
        "--input_name",
        type=str,
        default=None,
        help=(
            "Optional short name for the input file; used in output filenames. "
            "If not provided, derived from the basename of --dataset_path."
        ),
    )
    parser.add_argument(
        "--output_name",
        type=str,
        default=None,
        help=(
            "Base name (without extension) for output files. "
            "If not provided, uses vllm_inference_{model_tag}_{input_name_or_dataset}_{timestamp}."
        ),
    )
    parser.add_argument("--prompt-low-path", type=str, default=DEFAULT_PROMPT_LOW_PATH)
    parser.add_argument("--prompt-intermediate-path", type=str, default=DEFAULT_PROMPT_INTERMEDIATE_PATH)
    parser.add_argument("--prompt-proficient-path", type=str, default=DEFAULT_PROMPT_PROFICIENT_PATH)
    parser.add_argument("--output_dir", type=str, default=DEFAULT_OUTPUT_DIR)
http://127.0.0.1:8000/v1") parser.add_argument("--served_model_name", type=str, default=DEFAULT_SERVED_MODEL_NAME, help="Model name exposed by vLLM server.") parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--max_samples", type=int, default=-1, help="Use -1 for full dataset.") parser.add_argument("--max_tokens", type=int, default=1024) parser.add_argument("--temperature", type=float, default=0.1) parser.add_argument("--top_p", type=float, default=0.8) parser.add_argument("--api_key", type=str, default="EMPTY") parser.add_argument("--timeout_sec", type=int, default=300) parser.add_argument("--num_workers", type=int, default=4, help="Concurrent request threads to keep server pipeline full.") return parser.parse_args() def load_prompt_templates(args: argparse.Namespace) -> Dict[str, str]: prompt_path_by_label = { "low_health_literacy": args.prompt_low_path, "intermediate_health_literacy": args.prompt_intermediate_path, "proficient_health_literacy": args.prompt_proficient_path, } templates: Dict[str, str] = {} for label, path in prompt_path_by_label.items(): if not os.path.exists(path): raise FileNotFoundError(f"Prompt file not found: {path}") with open(path, "r", encoding="utf-8") as f: templates[label] = f.read() return templates def load_verified_rows(path: str) -> List[Dict[str, Any]]: if not os.path.exists(path): raise FileNotFoundError(f"Input file not found: {path}") with open(path, "r", encoding="utf-8") as f: parsed = json.load(f) if not isinstance(parsed, list): raise ValueError(f"Expected top-level JSON array in {path}") return [row for row in parsed if isinstance(row, dict)] def infer_source_lang(fulltext: str) -> str: if fulltext and any("a" <= ch.lower() <= "z" for ch in fulltext): return "English" return "Unknown" def split_into_subclaims(text: str, min_chars: int = 15) -> List[str]: """ Lightweight sentence splitter to approximate subclaims from a summary. 
""" if not text or not text.strip(): return [] parts = re.split(r"(?<=[.!?])\s+", text.strip()) return [s.strip() for s in parts if len(s.strip()) >= min_chars] def build_prompt(template: str, fulltext: str, summary: str, source_lang: str) -> str: return ( template.replace("{source_lang}", source_lang) .replace("{gold_summary}", summary) .replace("{full_text}", fulltext) ) def _clean_json_block(text: str) -> str: cleaned = text.strip() if "```json" in cleaned: cleaned = cleaned.split("```json", 1)[1].split("```", 1)[0].strip() elif "```" in cleaned: cleaned = cleaned.split("```", 1)[1].split("```", 1)[0].strip() return cleaned def extract_generated_text(raw_response: str, expected_label: str) -> str: cleaned = _clean_json_block(raw_response) try: parsed = json.loads(cleaned) except json.JSONDecodeError: return raw_response.strip() if isinstance(parsed, dict): value = parsed.get(expected_label) if isinstance(value, str) and value.strip(): return value.strip() return raw_response.strip() def _normalize_messages(prompt_obj: Any) -> List[Dict[str, str]]: if hasattr(prompt_obj, "tolist"): prompt_obj = prompt_obj.tolist() if isinstance(prompt_obj, dict): if "role" in prompt_obj and "content" in prompt_obj: return [{"role": str(prompt_obj["role"]), "content": str(prompt_obj["content"])}] return [{"role": "user", "content": json.dumps(prompt_obj, ensure_ascii=False)}] if isinstance(prompt_obj, list): messages = [] for item in prompt_obj: if isinstance(item, dict) and "role" in item and "content" in item: messages.append({"role": str(item["role"]), "content": str(item["content"])}) else: messages.append({"role": "user", "content": str(item)}) if messages: return messages return [{"role": "user", "content": str(prompt_obj)}] def build_prompt_text(tokenizer: AutoTokenizer, prompt_obj: Any) -> str: messages = _normalize_messages(prompt_obj) if tokenizer.chat_template: return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) return "\n".join(m["content"] for m in messages) + "\n\nAssistant:" def sanitize_model_tag(model_path: str, max_len: int = 80) -> str: tag = re.sub(r"[^A-Za-z0-9]+", "-", model_path).strip("-").lower() if not tag: return "unknown-model" if len(tag) > max_len: return tag[:max_len].rstrip("-") return tag def check_server(base_url: str, headers: Dict[str, str], timeout_sec: int) -> Optional[List[Dict[str, Any]]]: models_url = f"{base_url.rstrip('/')}/models" resp = requests.get(models_url, headers=headers, timeout=timeout_sec) resp.raise_for_status() payload = resp.json() return payload.get("data", []) def batched_completion_request( base_url: str, headers: Dict[str, str], model_name: str, prompts: List[str], max_tokens: int, temperature: float, top_p: float, timeout_sec: int, ) -> List[str]: payload = { "model": model_name, "prompt": prompts, "max_tokens": max_tokens, "temperature": temperature, "top_p": top_p, } url = f"{base_url.rstrip('/')}/completions" resp = requests.post(url, headers=headers, json=payload, timeout=timeout_sec) resp.raise_for_status() data = resp.json() choices = data.get("choices", []) preds = [""] * len(prompts) for choice in choices: idx = choice.get("index", None) text = str(choice.get("text", "")).strip() if isinstance(idx, int) and 0 <= idx < len(preds) and not preds[idx]: preds[idx] = text if any(not p for p in preds): fallback_texts = [str(c.get("text", "")).strip() for c in choices] for i in range(len(preds)): if not preds[i]: preds[i] = fallback_texts[i] if i < len(fallback_texts) else "" return preds def main() -> 
def main() -> None:
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    run_ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_tag = sanitize_model_tag(args.model_path)
    input_tag_raw = (
        args.input_name
        if args.input_name
        else os.path.splitext(os.path.basename(args.dataset_path))[0]
    )
    input_tag = sanitize_model_tag(input_tag_raw)
    default_base = f"vllm_inference_{model_tag}_{input_tag}_{run_ts}"
    base_name = args.output_name if args.output_name else default_base
    output_jsonl = os.path.join(args.output_dir, f"{base_name}.jsonl")
    meta_path = os.path.join(args.output_dir, f"{base_name}_meta.json")

    headers = {
        "Authorization": f"Bearer {args.api_key}",
        "Content-Type": "application/json",
    }

    print(f"[INFO] Checking vLLM server: {args.base_url}")
    models = check_server(args.base_url, headers=headers, timeout_sec=args.timeout_sec)
    available_model_ids = [m.get("id", "") for m in models or []]
    print(f"[INFO] Server models: {available_model_ids}")
    if args.served_model_name not in available_model_ids:
        print(
            f"[WARN] Served model '{args.served_model_name}' not found in /models. "
            "Will still try requests with provided name."
        )

    print(f"[INFO] Loading tokenizer from: {args.model_path}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

    print(f"[INFO] Reading dataset: {args.dataset_path}")
    templates = load_prompt_templates(args)
    rows = load_verified_rows(args.dataset_path)

    parsed_items: List[Dict[str, Any]] = []
    for idx, row in enumerate(rows):
        gold_label = str(row.get("label", "")).strip()
        fulltext = str(row.get("fulltext", "")).strip()
        summary = str(row.get("summary", "")).strip()
        if gold_label not in VALID_LABELS:
            continue
        if not fulltext or not summary:
            continue
        source_lang = infer_source_lang(fulltext)
        subclaims = split_into_subclaims(summary)
        prompt = build_prompt(
            template=templates[gold_label],
            fulltext=fulltext,
            summary=summary,
            source_lang=source_lang,
        )
        parsed_items.append(
            {
                "row_index": idx,
                "doc_id": row.get("doc_id"),
                "gold_label": gold_label,
                "source_lang": source_lang,
                "summary_text": summary,
                "input_text": fulltext,
                "subclaims": subclaims,
                "prompt": prompt,
            }
        )

    df = pd.DataFrame(parsed_items)
    if args.max_samples > 0:
        df = df.head(args.max_samples)
    print(f"[INFO] Rows to process: {len(df)}")
    if df.empty:
        raise RuntimeError("No valid rows found in input file.")

    batch_ranges = list(range(0, len(df), args.batch_size))
    total_batches = len(batch_ranges)
    num_workers = min(args.num_workers, total_batches)
    print(f"[INFO] {total_batches} batches × {args.batch_size} prompts, {num_workers} concurrent workers")

    t0 = time.time()

    def _process_batch(start: int) -> List[Dict[str, Any]]:
        batch_df = df.iloc[start : start + args.batch_size]
        prompts = [build_prompt_text(tokenizer, row.get("prompt", "")) for _, row in batch_df.iterrows()]
        preds = batched_completion_request(
            base_url=args.base_url,
            headers=headers,
            model_name=args.served_model_name,
            prompts=prompts,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            timeout_sec=args.timeout_sec,
        )
        records = []
        for (row_idx, row), pred in zip(batch_df.iterrows(), preds):
            gold_label = str(row.get("gold_label", ""))
            records.append(
                {
                    "row_index": int(row.get("row_index", row_idx)),
                    "doc_id": row.get("doc_id"),
                    "gold_label": gold_label,
                    "source_lang": row.get("source_lang"),
                    "summary_text": row.get("summary_text", ""),
                    "input_text": row.get("input_text", ""),
                    "subclaims": row.get("subclaims", []),
                    "prediction": pred,
                    "generated_text": extract_generated_text(pred, gold_label) if gold_label else pred.strip(),
                }
            )
        return records

    # Batches may finish out of order; buffer completed batches and flush them
    # to the JSONL file in the original batch order.
    pending_results: Dict[int, List[Dict[str, Any]]] = {}
    next_write_idx = 0
    outputs: List[Dict[str, Any]] = []
    with open(output_jsonl, "w", encoding="utf-8") as f_out:
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            future_to_idx = {
                executor.submit(_process_batch, batch_ranges[i]): i
                for i in range(total_batches)
            }
            pbar = tqdm(total=total_batches, desc="Batches")
            for future in as_completed(future_to_idx):
                batch_idx = future_to_idx[future]
                records = future.result()
                pending_results[batch_idx] = records
                pbar.update(1)
                while next_write_idx in pending_results:
                    for rec in pending_results.pop(next_write_idx):
                        outputs.append(rec)
                        f_out.write(json.dumps(rec, ensure_ascii=False) + "\n")
                    next_write_idx += 1
            pbar.close()

    elapsed = time.time() - t0
    print(f"[INFO] Inference done: {len(outputs)} samples in {elapsed:.1f}s ({len(outputs)/elapsed:.1f} samples/s)")

    with open(meta_path, "w", encoding="utf-8") as f_meta:
        json.dump(
            {
                "model_path_for_tokenizer": args.model_path,
                "dataset_path": args.dataset_path,
                "input_name": input_tag,
                "output_name": base_name,
                "base_url": args.base_url,
                "served_model_name": args.served_model_name,
                "batch_size": args.batch_size,
                "num_samples": len(outputs),
                "output_jsonl": output_jsonl,
            },
            f_meta,
            ensure_ascii=False,
            indent=2,
        )

    print("[DONE] vLLM batch inference complete.")
    print(f"[DONE] JSONL: {output_jsonl}")
    print(f"[DONE] Meta: {meta_path}")


if __name__ == "__main__":
    main()
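
# Example invocation (illustrative sketch; the script filename below is a
# placeholder, and the base URL / served model name must match whatever was
# used when launching the vLLM server):
#
#   python vllm_batch_inference.py \
#       --base_url http://127.0.0.1:8021/v1 \
#       --served_model_name inference \
#       --batch_size 64 \
#       --num_workers 4 \
#       --max_samples 100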