# Source: readCtrl_lambda/code/readctrl_rl_inference/run_inference_vllm_server.py
# Author: mshahidul
# Commit 030876e: "Initial commit of readCtrl code without large models" (15.3 kB)
import argparse
import json
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from typing import Any, Dict, List, Optional
import pandas as pd
import requests
from tqdm import tqdm
from transformers import AutoTokenizer
# Hugging Face id / local path used only for the tokenizer and chat template.
DEFAULT_MODEL_PATH = "Qwen/Qwen3-4B-Instruct-2507"

# Shared filesystem roots for the default input, output and prompt locations.
_READCTRL_CODE_ROOT = "/home/mshahidul/readctrl/code"
_INFERENCE_DIR = f"{_READCTRL_CODE_ROOT}/readctrl_rl_inference"
_PROMPT_DIR = f"{_READCTRL_CODE_ROOT}/RL_model/verl/verl_train/dataset"

DEFAULT_DATASET_PATH = f"{_INFERENCE_DIR}/verified_combined_0-80_clean200.json"
DEFAULT_OUTPUT_DIR = f"{_INFERENCE_DIR}/vllm_model_result"

# OpenAI-compatible endpoint of the vLLM server and the model name it serves.
DEFAULT_BASE_URL = "http://127.0.0.1:8021/v1"
DEFAULT_SERVED_MODEL_NAME = "inference"

# One prompt template file per target health-literacy level.
DEFAULT_PROMPT_LOW_PATH = f"{_PROMPT_DIR}/prompt_low"
DEFAULT_PROMPT_INTERMEDIATE_PATH = f"{_PROMPT_DIR}/prompt_intermediate"
DEFAULT_PROMPT_PROFICIENT_PATH = f"{_PROMPT_DIR}/prompt_proficient"

# Dataset labels accepted for inference; rows with any other label are skipped.
VALID_LABELS = {
    f"{level}_health_literacy" for level in ("low", "intermediate", "proficient")
}
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command-line options for a vLLM batch-inference run.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, so existing ``parse_args()`` calls behave as before,
            while wrappers can pass an explicit list.

    Returns:
        The populated namespace. Dashed options such as ``--prompt-low-path``
        are exposed with underscores (``args.prompt_low_path``).

    Exits (via argparse) when ``--batch_size``, ``--timeout_sec`` or
    ``--num_workers`` is not a positive integer.
    """

    def _positive_int(raw: str) -> int:
        # Reject 0/negatives at parse time; otherwise they surface much later as
        # opaque range()/ThreadPoolExecutor/requests timeout errors.
        try:
            value = int(raw)
        except ValueError:
            raise argparse.ArgumentTypeError(f"expected an integer, got {raw!r}") from None
        if value <= 0:
            raise argparse.ArgumentTypeError(f"expected a positive integer, got {value}")
        return value

    parser = argparse.ArgumentParser(description="Run batched inference via vLLM OpenAI-compatible server.")
    parser.add_argument("--model_path", type=str, default=DEFAULT_MODEL_PATH, help="Local path for tokenizer/chat template.")
    parser.add_argument("--dataset_path", type=str, default=DEFAULT_DATASET_PATH)
    parser.add_argument(
        "--input_name",
        type=str,
        default=None,
        help=(
            "Optional short name for the input file; used in output filenames. "
            "If not provided, derived from the basename of --dataset_path."
        ),
    )
    parser.add_argument(
        "--output_name",
        type=str,
        default=None,
        help=(
            "Base name (without extension) for output files. "
            "If not provided, uses vllm_inference_{model_tag}_{input_name_or_dataset}_{timestamp}."
        ),
    )
    parser.add_argument("--prompt-low-path", type=str, default=DEFAULT_PROMPT_LOW_PATH)
    parser.add_argument("--prompt-intermediate-path", type=str, default=DEFAULT_PROMPT_INTERMEDIATE_PATH)
    parser.add_argument("--prompt-proficient-path", type=str, default=DEFAULT_PROMPT_PROFICIENT_PATH)
    parser.add_argument("--output_dir", type=str, default=DEFAULT_OUTPUT_DIR)
    parser.add_argument("--base_url", type=str, default=DEFAULT_BASE_URL, help="vLLM OpenAI base URL, e.g. http://127.0.0.1:8000/v1")
    parser.add_argument("--served_model_name", type=str, default=DEFAULT_SERVED_MODEL_NAME, help="Model name exposed by vLLM server.")
    parser.add_argument("--batch_size", type=_positive_int, default=64)
    parser.add_argument("--max_samples", type=int, default=-1, help="Use -1 for full dataset.")
    parser.add_argument("--max_tokens", type=int, default=1024)
    parser.add_argument("--temperature", type=float, default=0.1)
    parser.add_argument("--top_p", type=float, default=0.8)
    parser.add_argument("--api_key", type=str, default="EMPTY")
    parser.add_argument("--timeout_sec", type=_positive_int, default=300)
    parser.add_argument("--num_workers", type=_positive_int, default=4, help="Concurrent request threads to keep server pipeline full.")
    return parser.parse_args(argv)
def load_prompt_templates(args: argparse.Namespace) -> Dict[str, str]:
    """Read the prompt template for each health-literacy label.

    The three paths come from ``args.prompt_low_path``,
    ``args.prompt_intermediate_path`` and ``args.prompt_proficient_path`` and
    are read as UTF-8 text, in that order.

    Returns:
        Mapping of label -> raw template text.

    Raises:
        FileNotFoundError: the first path (in the order above) that does not exist.
    """
    label_and_path = (
        ("low_health_literacy", args.prompt_low_path),
        ("intermediate_health_literacy", args.prompt_intermediate_path),
        ("proficient_health_literacy", args.prompt_proficient_path),
    )
    loaded: Dict[str, str] = {}
    for label, template_path in label_and_path:
        if os.path.exists(template_path):
            with open(template_path, "r", encoding="utf-8") as handle:
                loaded[label] = handle.read()
            continue
        raise FileNotFoundError(f"Prompt file not found: {template_path}")
    return loaded
def load_verified_rows(path: str) -> List[Dict[str, Any]]:
    """Load a JSON file whose top level is an array and keep only its object entries.

    Raises:
        FileNotFoundError: ``path`` does not exist.
        ValueError: the decoded top-level value is not a JSON array
            (``json.JSONDecodeError`` also propagates for malformed JSON).
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Input file not found: {path}")
    with open(path, "r", encoding="utf-8") as handle:
        records = json.load(handle)
    if isinstance(records, list):
        # Non-object entries (numbers, strings, nested arrays, null) are dropped silently.
        return list(filter(lambda entry: isinstance(entry, dict), records))
    raise ValueError(f"Expected top-level JSON array in {path}")
def infer_source_lang(fulltext: str) -> str:
    """Crude language guess used to fill the ``{source_lang}`` prompt slot.

    Returns ``"English"`` as soon as any character lower-cases into the range
    ``"a"``..``"z"``, and ``"Unknown"`` for empty text or text without such a
    character. Note that any Latin-script text (not only English) counts.
    """
    if not fulltext:
        return "Unknown"
    for character in fulltext:
        if "a" <= character.lower() <= "z":
            return "English"
    return "Unknown"
def split_into_subclaims(text: str, min_chars: int = 15) -> List[str]:
    """Approximate a summary's subclaims by splitting it into sentences.

    A sentence boundary is whitespace that follows ``.``, ``!`` or ``?``.
    Each piece is stripped, and pieces shorter than ``min_chars`` characters
    are discarded. Empty or whitespace-only input yields ``[]``.
    """
    if not text or not text.strip():
        return []
    sentences = re.split(r"(?<=[.!?])\s+", text.strip())
    kept: List[str] = []
    for sentence in sentences:
        candidate = sentence.strip()
        if len(candidate) >= min_chars:
            kept.append(candidate)
    return kept
def build_prompt(template: str, fulltext: str, summary: str, source_lang: str) -> str:
    """Fill the ``{source_lang}``, ``{gold_summary}`` and ``{full_text}`` slots of a template.

    Substitution is a single pass over ``template``. Text that has already been
    inserted is never re-scanned, so a summary that happens to contain the
    literal ``{full_text}`` stays as written instead of having the whole article
    spliced into it. Any other braces in the template (e.g. JSON examples) are
    left untouched, which is why ``str.format`` is not used.
    """
    values = {
        "source_lang": source_lang,
        "gold_summary": summary,
        "full_text": fulltext,
    }
    # A function replacement also keeps backslashes in the values literal
    # (a string ``repl`` would interpret sequences like ``\1``).
    return re.sub(
        r"\{(source_lang|gold_summary|full_text)\}",
        lambda match: values[match.group(1)],
        template,
    )
def _clean_json_block(text: str) -> str:
cleaned = text.strip()
if "```json" in cleaned:
cleaned = cleaned.split("```json", 1)[1].split("```", 1)[0].strip()
elif "```" in cleaned:
cleaned = cleaned.split("```", 1)[1].split("```", 1)[0].strip()
return cleaned
def extract_generated_text(raw_response: str, expected_label: str) -> str:
    """Pull the rewritten text for ``expected_label`` out of a model response.

    The response is first unwrapped from a Markdown code fence (if any) and
    parsed as JSON. When that yields an object whose ``expected_label`` entry
    is a non-blank string, that string is returned stripped. In every other
    case (invalid JSON, non-object JSON, missing/blank/non-string entry) the
    whitespace-stripped raw response is returned.
    """
    candidate_json = _clean_json_block(raw_response)
    fallback = raw_response.strip()
    try:
        decoded = json.loads(candidate_json)
    except json.JSONDecodeError:
        return fallback
    field = decoded.get(expected_label) if isinstance(decoded, dict) else None
    if isinstance(field, str) and field.strip():
        return field.strip()
    return fallback
def _normalize_messages(prompt_obj: Any) -> List[Dict[str, str]]:
if hasattr(prompt_obj, "tolist"):
prompt_obj = prompt_obj.tolist()
if isinstance(prompt_obj, dict):
if "role" in prompt_obj and "content" in prompt_obj:
return [{"role": str(prompt_obj["role"]), "content": str(prompt_obj["content"])}]
return [{"role": "user", "content": json.dumps(prompt_obj, ensure_ascii=False)}]
if isinstance(prompt_obj, list):
messages = []
for item in prompt_obj:
if isinstance(item, dict) and "role" in item and "content" in item:
messages.append({"role": str(item["role"]), "content": str(item["content"])})
else:
messages.append({"role": "user", "content": str(item)})
if messages:
return messages
return [{"role": "user", "content": str(prompt_obj)}]
def build_prompt_text(tokenizer: AutoTokenizer, prompt_obj: Any) -> str:
    """Render a prompt into the raw completion string sent to the server.

    When the tokenizer has a chat template, the normalized messages are rendered
    through it (untokenized, with the generation prompt appended). Otherwise the
    message contents are newline-joined and followed by ``"\\n\\nAssistant:"``.
    """
    chat = _normalize_messages(prompt_obj)
    if not tokenizer.chat_template:
        joined = "\n".join(message["content"] for message in chat)
        return f"{joined}\n\nAssistant:"
    return tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
def sanitize_model_tag(model_path: str, max_len: int = 80) -> str:
    """Turn an arbitrary string into a lowercase, dash-separated filename tag.

    Runs of non-alphanumeric characters collapse to a single ``-`` and edge
    dashes are trimmed. An empty result becomes ``"unknown-model"``. Tags
    longer than ``max_len`` are truncated, and any dash left at the cut is
    removed.
    """
    slug = re.sub(r"[^A-Za-z0-9]+", "-", model_path).strip("-").lower()
    if not slug:
        return "unknown-model"
    return slug[:max_len].rstrip("-") if len(slug) > max_len else slug
def check_server(base_url: str, headers: Dict[str, str], timeout_sec: int) -> Optional[List[Dict[str, Any]]]:
    """GET ``<base_url>/models`` and return the ``data`` entry of the JSON reply.

    Raises ``requests.HTTPError`` (via ``raise_for_status``) on a non-2xx reply.
    Returns ``[]`` when the key is absent; the value is passed through as-is
    otherwise (so it may be ``None`` if the server sends ``null``).
    """
    endpoint = base_url.rstrip("/") + "/models"
    response = requests.get(endpoint, headers=headers, timeout=timeout_sec)
    response.raise_for_status()
    body = response.json()
    return body.get("data", [])
def _choices_to_predictions(choices: List[Dict[str, Any]], num_prompts: int) -> List[str]:
    """Map completion ``choices`` back onto prompt positions (one string per prompt).

    A choice with an integer ``index`` in range fills that slot (the first
    non-empty text wins if an index repeats). Slots that no indexed choice
    claimed are filled, in order, from choices lacking a usable index, and
    any slot still unfilled becomes ``""``. A prompt whose own generation is
    empty therefore keeps ``""`` instead of inheriting another prompt's output.
    """
    slots: List[Optional[str]] = [None] * num_prompts
    unindexed: List[str] = []
    for choice in choices:
        idx = choice.get("index")
        text = str(choice.get("text", "")).strip()
        if isinstance(idx, int) and 0 <= idx < num_prompts:
            if not slots[idx]:
                slots[idx] = text
        else:
            unindexed.append(text)
    leftovers = iter(unindexed)
    return [slot if slot is not None else next(leftovers, "") for slot in slots]


def batched_completion_request(
    base_url: str,
    headers: Dict[str, str],
    model_name: str,
    prompts: List[str],
    max_tokens: int,
    temperature: float,
    top_p: float,
    timeout_sec: int,
) -> List[str]:
    """Send all ``prompts`` in one ``POST <base_url>/completions`` call.

    Returns one stripped completion string per prompt, aligned with
    ``prompts``. Raises ``requests.HTTPError`` on a non-2xx reply.
    """
    payload = {
        "model": model_name,
        "prompt": prompts,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    url = f"{base_url.rstrip('/')}/completions"
    resp = requests.post(url, headers=headers, json=payload, timeout=timeout_sec)
    resp.raise_for_status()
    # ``or []`` also covers an explicit ``"choices": null``.
    choices = resp.json().get("choices") or []
    return _choices_to_predictions(choices, len(prompts))
def _build_items(rows: List[Dict[str, Any]], templates: Dict[str, str]) -> List[Dict[str, Any]]:
    """Convert raw dataset rows into prompt-ready work items, skipping unusable rows.

    A row is kept only if its stripped ``label`` is in ``VALID_LABELS`` and its
    stripped ``fulltext`` and ``summary`` are non-empty. ``row_index`` is the
    row's position in ``rows``, so outputs can be traced back to the input file.
    """
    items: List[Dict[str, Any]] = []
    for idx, row in enumerate(rows):
        gold_label = str(row.get("label", "")).strip()
        fulltext = str(row.get("fulltext", "")).strip()
        summary = str(row.get("summary", "")).strip()
        if gold_label not in VALID_LABELS or not fulltext or not summary:
            continue
        source_lang = infer_source_lang(fulltext)
        items.append(
            {
                "row_index": idx,
                "doc_id": row.get("doc_id"),
                "gold_label": gold_label,
                "source_lang": source_lang,
                "summary_text": summary,
                "input_text": fulltext,
                "subclaims": split_into_subclaims(summary),
                "prompt": build_prompt(
                    template=templates[gold_label],
                    fulltext=fulltext,
                    summary=summary,
                    source_lang=source_lang,
                ),
            }
        )
    return items


def _process_batch(
    batch: List[Dict[str, Any]],
    tokenizer: AutoTokenizer,
    args: argparse.Namespace,
    headers: Dict[str, str],
) -> List[Dict[str, Any]]:
    """Send one batch of work items to the server and build its output records."""
    prompts = [build_prompt_text(tokenizer, item["prompt"]) for item in batch]
    preds = batched_completion_request(
        base_url=args.base_url,
        headers=headers,
        model_name=args.served_model_name,
        prompts=prompts,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        timeout_sec=args.timeout_sec,
    )
    records: List[Dict[str, Any]] = []
    for item, pred in zip(batch, preds):
        gold_label = str(item["gold_label"])
        records.append(
            {
                "row_index": int(item["row_index"]),
                "doc_id": item["doc_id"],
                "gold_label": gold_label,
                "source_lang": item["source_lang"],
                "summary_text": item["summary_text"],
                "input_text": item["input_text"],
                "subclaims": item["subclaims"],
                "prediction": pred,
                "generated_text": extract_generated_text(pred, gold_label)
                if gold_label
                else pred.strip(),
            }
        )
    return records


def _run_batches(
    items: List[Dict[str, Any]],
    tokenizer: AutoTokenizer,
    args: argparse.Namespace,
    headers: Dict[str, str],
    output_jsonl: str,
) -> List[Dict[str, Any]]:
    """Run every batch on a thread pool, streaming records to ``output_jsonl`` in input order.

    Returns all written records. If any batch raises (or the run is
    interrupted), batches that have not started yet are cancelled and the
    exception is re-raised. Records already written stay in the file.
    """
    batch_starts = list(range(0, len(items), args.batch_size))
    total_batches = len(batch_starts)
    num_workers = min(args.num_workers, total_batches)
    print(f"[INFO] {total_batches} batches × {args.batch_size} prompts, {num_workers} concurrent workers")

    pending_results: Dict[int, List[Dict[str, Any]]] = {}
    next_write_idx = 0
    outputs: List[Dict[str, Any]] = []
    with open(output_jsonl, "w", encoding="utf-8") as f_out, ThreadPoolExecutor(max_workers=num_workers) as executor:
        future_to_idx = {
            executor.submit(_process_batch, items[start : start + args.batch_size], tokenizer, args, headers): batch_idx
            for batch_idx, start in enumerate(batch_starts)
        }
        try:
            with tqdm(total=total_batches, desc="Batches") as pbar:
                for future in as_completed(future_to_idx):
                    pending_results[future_to_idx[future]] = future.result()
                    pbar.update(1)
                    # Batches finish in arbitrary order; flush only the contiguous
                    # prefix so the JSONL lines follow input order.
                    while next_write_idx in pending_results:
                        for rec in pending_results.pop(next_write_idx):
                            outputs.append(rec)
                            f_out.write(json.dumps(rec, ensure_ascii=False) + "\n")
                        next_write_idx += 1
        except BaseException:  # includes KeyboardInterrupt; always re-raised
            # Leaving the executor's ``with`` waits for every queued batch;
            # cancel those not yet started so a failure stops the run promptly.
            for future in future_to_idx:
                future.cancel()
            raise
    return outputs


def _write_meta(
    meta_path: str,
    args: argparse.Namespace,
    input_tag: str,
    base_name: str,
    output_jsonl: str,
    num_samples: int,
    run_ts: str,
) -> None:
    """Write the run's settings and output location as pretty-printed JSON."""
    meta = {
        "model_path_for_tokenizer": args.model_path,
        "dataset_path": args.dataset_path,
        "input_name": input_tag,
        "output_name": base_name,
        "base_url": args.base_url,
        "served_model_name": args.served_model_name,
        "batch_size": args.batch_size,
        "num_samples": num_samples,
        "output_jsonl": output_jsonl,
        # Extra keys (appended after the originals) so a run can be reproduced
        # from its meta file alone.
        "run_timestamp": run_ts,
        "max_samples": args.max_samples,
        "max_tokens": args.max_tokens,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "num_workers": args.num_workers,
        "prompt_low_path": args.prompt_low_path,
        "prompt_intermediate_path": args.prompt_intermediate_path,
        "prompt_proficient_path": args.prompt_proficient_path,
    }
    with open(meta_path, "w", encoding="utf-8") as f_meta:
        json.dump(meta, f_meta, ensure_ascii=False, indent=2)


def main() -> None:
    """Run batched generation against a vLLM server and write JSONL + meta outputs.

    Steps: parse CLI options, list the server's models, load the tokenizer,
    prompt templates and dataset, build one prompt per valid row, send them in
    concurrent batches, then write ``<base_name>.jsonl`` and
    ``<base_name>_meta.json`` into ``--output_dir``.

    Raises:
        RuntimeError: no dataset row survives filtering.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    run_ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_tag = sanitize_model_tag(args.model_path)
    input_tag = sanitize_model_tag(
        args.input_name or os.path.splitext(os.path.basename(args.dataset_path))[0]
    )
    base_name = args.output_name or f"vllm_inference_{model_tag}_{input_tag}_{run_ts}"
    output_jsonl = os.path.join(args.output_dir, f"{base_name}.jsonl")
    meta_path = os.path.join(args.output_dir, f"{base_name}_meta.json")
    headers = {
        "Authorization": f"Bearer {args.api_key}",
        "Content-Type": "application/json",
    }

    print(f"[INFO] Checking vLLM server: {args.base_url}")
    models = check_server(args.base_url, headers=headers, timeout_sec=args.timeout_sec)
    available_model_ids = [m.get("id", "") for m in models or []]
    print(f"[INFO] Server models: {available_model_ids}")
    if args.served_model_name not in available_model_ids:
        print(
            f"[WARN] Served model '{args.served_model_name}' not found in /models. "
            "Will still try requests with provided name."
        )

    print(f"[INFO] Loading tokenizer from: {args.model_path}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
    print(f"[INFO] Reading dataset: {args.dataset_path}")
    templates = load_prompt_templates(args)
    # Items stay plain dicts (no DataFrame round-trip) so values such as doc_id
    # are written exactly as read, never coerced to float/NaN by pandas.
    items = _build_items(load_verified_rows(args.dataset_path), templates)
    if args.max_samples > 0:
        items = items[: args.max_samples]
    print(f"[INFO] Rows to process: {len(items)}")
    if not items:
        raise RuntimeError("No valid rows found in input file.")

    t0 = time.time()
    outputs = _run_batches(items, tokenizer, args, headers, output_jsonl)
    elapsed = max(time.time() - t0, 1e-9)  # guard the throughput division
    print(f"[INFO] Inference done: {len(outputs)} samples in {elapsed:.1f}s ({len(outputs)/elapsed:.1f} samples/s)")

    _write_meta(meta_path, args, input_tag, base_name, output_jsonl, len(outputs), run_ts)
    print("[DONE] vLLM batch inference complete.")
    print(f"[DONE] JSONL: {output_jsonl}")
    print(f"[DONE] Meta: {meta_path}")
# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()